aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/outpainting_mk_2.py139
-rw-r--r--scripts/xy_grid.py1
2 files changed, 81 insertions, 59 deletions
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index a6468e09..2afd4aa5 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -172,54 +172,54 @@ class Script(scripts.Script):
if down > 0:
down = target_h - init_img.height - up
- init_image = p.init_images[0]
-
- state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
-
- def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+ def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
- res_w = init.width + pixels_horiz
- res_h = init.height + pixels_vert
- process_res_w = math.ceil(res_w / 64) * 64
- process_res_h = math.ceil(res_h / 64) * 64
-
- img = Image.new("RGB", (process_res_w, process_res_h))
- img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
- mask = Image.new("RGB", (process_res_w, process_res_h), "white")
- draw = ImageDraw.Draw(mask)
- draw.rectangle((
- expand_pixels + mask_blur if is_left else 0,
- expand_pixels + mask_blur if is_top else 0,
- mask.width - expand_pixels - mask_blur if is_right else res_w,
- mask.height - expand_pixels - mask_blur if is_bottom else res_h,
- ), fill="black")
-
- np_image = (np.asarray(img) / 255.0).astype(np.float64)
- np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
- noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
- out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
-
- target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
- target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
-
- crop_region = (
- 0 if is_left else out.width - target_width,
- 0 if is_top else out.height - target_height,
- target_width if is_left else out.width,
- target_height if is_top else out.height,
- )
-
- image_to_process = out.crop(crop_region)
- mask = mask.crop(crop_region)
-
- p.width = target_width if is_horiz else img.width
- p.height = target_height if is_vert else img.height
- p.init_images = [image_to_process]
- p.image_mask = mask
+ images_to_process = []
+ output_images = []
+ for n in range(count):
+ res_w = init[n].width + pixels_horiz
+ res_h = init[n].height + pixels_vert
+ process_res_w = math.ceil(res_w / 64) * 64
+ process_res_h = math.ceil(res_h / 64) * 64
+
+ img = Image.new("RGB", (process_res_w, process_res_h))
+ img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+ mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ expand_pixels + mask_blur if is_left else 0,
+ expand_pixels + mask_blur if is_top else 0,
+ mask.width - expand_pixels - mask_blur if is_right else res_w,
+ mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+ ), fill="black")
+
+ np_image = (np.asarray(img) / 255.0).astype(np.float64)
+ np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+ noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+ output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
+
+ target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+ target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+ p.width = target_width if is_horiz else img.width
+ p.height = target_height if is_vert else img.height
+
+ crop_region = (
+ 0 if is_left else output_images[n].width - target_width,
+ 0 if is_top else output_images[n].height - target_height,
+ target_width if is_left else output_images[n].width,
+ target_height if is_top else output_images[n].height,
+ )
+ mask = mask.crop(crop_region)
+ p.image_mask = mask
+
+ image_to_process = output_images[n].crop(crop_region)
+ images_to_process.append(image_to_process)
+
+ p.init_images = images_to_process
latent_mask = Image.new("RGB", (p.width, p.height), "white")
draw = ImageDraw.Draw(latent_mask)
@@ -232,31 +232,52 @@ class Script(scripts.Script):
p.latent_mask = latent_mask
proc = process_images(p)
- proc_img = proc.images[0]
if initial_seed_and_info[0] is None:
initial_seed_and_info[0] = proc.seed
initial_seed_and_info[1] = proc.info
- out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
- out = out.crop((0, 0, res_w, res_h))
- return out
+ for n in range(count):
+ output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
+ output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
- img = init_image
+ return output_images
- if left > 0:
- img = expand(img, left, is_left=True)
- if right > 0:
- img = expand(img, right, is_right=True)
- if up > 0:
- img = expand(img, up, is_top=True)
- if down > 0:
- img = expand(img, down, is_bottom=True)
+ batch_count = p.n_iter
+ batch_size = p.batch_size
+ p.n_iter = 1
+ state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
+ all_processed_images = []
+
+ for i in range(batch_count):
+ imgs = [init_img] * batch_size
+ state.job = f"Batch {i + 1} out of {batch_count}"
+
+ if left > 0:
+ imgs = expand(imgs, batch_size, left, is_left=True)
+ if right > 0:
+ imgs = expand(imgs, batch_size, right, is_right=True)
+ if up > 0:
+ imgs = expand(imgs, batch_size, up, is_top=True)
+ if down > 0:
+ imgs = expand(imgs, batch_size, down, is_bottom=True)
- res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
+ all_processed_images += imgs
+
+ all_images = all_processed_images
+
+ combined_grid_image = images.image_grid(all_processed_images)
+ unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
+ if opts.return_grid and not unwanted_grid_because_of_img_count:
+ all_images = [combined_grid_image] + all_processed_images
+
+ res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
if opts.samples_save:
- images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+ for img in all_processed_images:
+ images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
- return res
+ if opts.grid_save and not unwanted_grid_because_of_img_count:
+ images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+ return res
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 5cca168a..eff0c942 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
+ p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs):