aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-16 16:14:41 +0300
committerGitHub <noreply@github.com>2023-01-16 16:14:41 +0300
commitd073637e104fa7256e0b0c546b5b83a16b881b8a (patch)
treec86b6a74df17351740213137f7284620094f095b
parent064983c0adb00cd9e88d2f06f66c9a1d5bc116c3 (diff)
parent029260b4ca7267d7a75319dbc11bca2a8c52774e (diff)
Merge pull request #6803 from space-nuko/xy-grid-performance-improvement
Optimize XY grid to run slower axes fewer times
-rw-r--r--scripts/xy_grid.py123
1 files changed, 70 insertions, 53 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 01dd3eae..e06c11cb 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -178,76 +178,87 @@ def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
-AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
-AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
+AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
+AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])
axis_options = [
- AxisOption("Nothing", str, do_nothing, format_nothing, None),
- AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
- AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
- AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
- AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
- AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
- AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
- AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
- AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
- AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
- AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
- AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
- AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
- AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
- AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
- AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
- AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
- AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
- AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
- AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
- AxisOption("VAE", str, apply_vae, format_value_add_label, None),
- AxisOption("Styles", str, apply_styles, format_value_add_label, None),
+ AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
+ AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
+ AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
+ AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
+ AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
+ AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
+ AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
+ AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
+ AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
+ AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
+ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
+ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
+ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
+ AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
+ AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
+ AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
+ AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
+ AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
+ AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
]
-def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
# Temporary list of all the images that are generated to be populated into the grid.
# Will be filled with empty images for any individual step that fails to process properly
- image_cache = []
+ image_cache = [None] * (len(xs) * len(ys))
processed_result = None
cell_mode = "P"
- cell_size = (1,1)
+ cell_size = (1, 1)
state.job_count = len(xs) * len(ys) * p.n_iter
- for iy, y in enumerate(ys):
+ def process_cell(x, y, ix, iy):
+ nonlocal image_cache, processed_result, cell_mode, cell_size
+
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+
+ processed: Processed = cell(x, y)
+
+ try:
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+
+ image_cache[ix + iy * len(xs)] = processed_image
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
+ except:
+ image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
+
+ if swap_axes_processing_order:
for ix, x in enumerate(xs):
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
-
- processed:Processed = cell(x, y)
- try:
- # this dereference will throw an exception if the image was not processed
- # (this happens in cases such as if the user stops the process from the UI)
- processed_image = processed.images[0]
-
- if processed_result is None:
- # Use our first valid processed result as a template container to hold our full results
- processed_result = copy(processed)
- cell_mode = processed_image.mode
- cell_size = processed_image.size
- processed_result.images = [Image.new(cell_mode, cell_size)]
-
- image_cache.append(processed_image)
- if include_lone_images:
- processed_result.images.append(processed_image)
- processed_result.all_prompts.append(processed.prompt)
- processed_result.all_seeds.append(processed.seed)
- processed_result.infotexts.append(processed.infotexts[0])
- except:
- image_cache.append(Image.new(cell_mode, cell_size))
+ for iy, y in enumerate(ys):
+ process_cell(x, y, ix, iy)
+ else:
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, ix, iy)
if not processed_result:
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
@@ -417,6 +428,11 @@ class Script(scripts.Script):
grid_infotext = [None]
+ # If one of the axes is very slow to change between (like SD model
+ # checkpoint), then make sure it is in the outer iteration of the nested
+ # `for` loop.
+ swap_axes_processing_order = x_opt.cost > y_opt.cost
+
def cell(x, y):
if shared.state.interrupted:
return Processed(p, [], p.seed, "")
@@ -455,7 +471,8 @@ class Script(scripts.Script):
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell,
draw_legend=draw_legend,
- include_lone_images=include_lone_images
+ include_lone_images=include_lone_images,
+ swap_axes_processing_order=swap_axes_processing_order
)
if opts.grid_save: