aboutsummaryrefslogtreecommitdiff
path: root/scripts/outpainting_mk_2.py
diff options
context:
space:
mode:
authorwywywywy <wywywywy@gmail.com>2022-10-20 16:02:32 +0100
committerGitHub <noreply@github.com>2022-10-20 16:02:32 +0100
commit91efe138b35dda65e83070c14e9eb94f481fe476 (patch)
tree36ab0706d3af58b0b0ad1c9f013305e9efbe5e42 /scripts/outpainting_mk_2.py
parent4281f255d5e7c67515d619f53654be59a6fc1e13 (diff)
Implemented batch_size logic in outpainting_mk2
Diffstat (limited to 'scripts/outpainting_mk_2.py')
-rw-r--r--scripts/outpainting_mk_2.py118
1 files changed, 63 insertions, 55 deletions
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index 02e655e9..0377ab32 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -176,50 +176,53 @@ class Script(scripts.Script):
state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
- def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+ def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
- res_w = init.width + pixels_horiz
- res_h = init.height + pixels_vert
- process_res_w = math.ceil(res_w / 64) * 64
- process_res_h = math.ceil(res_h / 64) * 64
-
- img = Image.new("RGB", (process_res_w, process_res_h))
- img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
- mask = Image.new("RGB", (process_res_w, process_res_h), "white")
- draw = ImageDraw.Draw(mask)
- draw.rectangle((
- expand_pixels + mask_blur if is_left else 0,
- expand_pixels + mask_blur if is_top else 0,
- mask.width - expand_pixels - mask_blur if is_right else res_w,
- mask.height - expand_pixels - mask_blur if is_bottom else res_h,
- ), fill="black")
-
- np_image = (np.asarray(img) / 255.0).astype(np.float64)
- np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
- noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
- out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
-
- target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
- target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
-
- crop_region = (
- 0 if is_left else out.width - target_width,
- 0 if is_top else out.height - target_height,
- target_width if is_left else out.width,
- target_height if is_top else out.height,
- )
-
- image_to_process = out.crop(crop_region)
- mask = mask.crop(crop_region)
-
- p.width = target_width if is_horiz else img.width
- p.height = target_height if is_vert else img.height
- p.init_images = [image_to_process]
- p.image_mask = mask
+ images_to_process = []
+ for n in range(count):
+ res_w = init[n].width + pixels_horiz
+ res_h = init[n].height + pixels_vert
+ process_res_w = math.ceil(res_w / 64) * 64
+ process_res_h = math.ceil(res_h / 64) * 64
+
+ img = Image.new("RGB", (process_res_w, process_res_h))
+ img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+ mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ expand_pixels + mask_blur if is_left else 0,
+ expand_pixels + mask_blur if is_top else 0,
+ mask.width - expand_pixels - mask_blur if is_right else res_w,
+ mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+ ), fill="black")
+
+ np_image = (np.asarray(img) / 255.0).astype(np.float64)
+ np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+ noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+ out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+
+ target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+ target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+ p.width = target_width if is_horiz else img.width
+ p.height = target_height if is_vert else img.height
+
+ crop_region = (
+ 0 if is_left else out.width - target_width,
+ 0 if is_top else out.height - target_height,
+ target_width if is_left else out.width,
+ target_height if is_top else out.height,
+ )
+ mask = mask.crop(crop_region)
+ p.image_mask = mask
+
+ image_to_process = out.crop(crop_region)
+ images_to_process.append(image_to_process)
+
+ p.init_images = images_to_process
latent_mask = Image.new("RGB", (p.width, p.height), "white")
draw = ImageDraw.Draw(latent_mask)
@@ -232,44 +235,49 @@ class Script(scripts.Script):
p.latent_mask = latent_mask
proc = process_images(p)
- proc_img = proc.images[0]
if initial_seed_and_info[0] is None:
initial_seed_and_info[0] = proc.seed
initial_seed_and_info[1] = proc.info
- out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
- out = out.crop((0, 0, res_w, res_h))
- return out
+ for proc_img in proc.images:
+ out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
+ out = out.crop((0, 0, res_w, res_h))
+
+ return proc.images
batch_count = p.n_iter
+ batch_size = p.batch_size
p.n_iter = 1
state.job_count = batch_count
- all_images = []
+ all_processed_images = []
for i in range(batch_count):
- img = init_image
- state.job = f"Batch {i + 1} out of {state.job_count}"
+ imgs = [init_img] * batch_size
+ state.job = f"Batch {i + 1} out of {batch_count}"
if left > 0:
- img = expand(img, left, is_left=True)
+ imgs = expand(imgs, batch_size, left, is_left=True)
if right > 0:
- img = expand(img, right, is_right=True)
+ imgs = expand(imgs, batch_size, right, is_right=True)
if up > 0:
- img = expand(img, up, is_top=True)
+ imgs = expand(imgs, batch_size, up, is_top=True)
if down > 0:
- img = expand(img, down, is_bottom=True)
+ imgs = expand(imgs, batch_size, down, is_bottom=True)
- all_images.append(img)
+ all_processed_images += imgs
+
+ combined_grid_image = images.image_grid(all_processed_images)
+ all_images = all_processed_images
- combined_grid_image = images.image_grid(all_images)
if opts.return_grid:
- all_images = [combined_grid_image] + all_images
-
+ all_images = [combined_grid_image] + all_processed_images
+
res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
if opts.samples_save:
- images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+ for img in all_processed_images:
+ images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
if opts.grid_save:
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)