aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-01 17:01:06 +0300
committerGitHub <noreply@github.com>2024-01-01 17:01:06 +0300
commit7ba02e0b7cfc85d5d237eba71ab4d66564857d55 (patch)
treeb317227a0e63c42aa8c4b06147761dfd37ae24fc /scripts
parentbe31e7e71a08dc27543d31aa6e6532463ccbf20f (diff)
parent15156cde18844f459ba101b1356d162aa7f39c47 (diff)
Merge branch 'dev' into finer-settings-freezing-control
Diffstat (limited to 'scripts')
-rw-r--r--scripts/loopback.py6
-rw-r--r--scripts/postprocessing_caption.py30
-rw-r--r--scripts/postprocessing_codeformer.py16
-rw-r--r--scripts/postprocessing_create_flipped_copies.py32
-rw-r--r--scripts/postprocessing_focal_crop.py54
-rw-r--r--scripts/postprocessing_gfpgan.py13
-rw-r--r--scripts/postprocessing_split_oversized.py71
-rw-r--r--scripts/postprocessing_upscale.py12
-rw-r--r--scripts/processing_autosized_crop.py64
-rw-r--r--scripts/prompts_from_file.py17
-rw-r--r--scripts/xyz_grid.py28
11 files changed, 318 insertions, 25 deletions
diff --git a/scripts/loopback.py b/scripts/loopback.py
index 2d5feaf9..800ee882 100644
--- a/scripts/loopback.py
+++ b/scripts/loopback.py
@@ -95,7 +95,7 @@ class Script(scripts.Script):
processed = processing.process_images(p)
# Generation cancelled.
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
if initial_seed is None:
@@ -122,8 +122,8 @@ class Script(scripts.Script):
p.inpainting_fill = original_inpainting_fill
- if state.interrupted:
- break
+ if state.interrupted or state.stopping_generation:
+ break
if len(history) > 1:
grid = images.image_grid(history, rows=1)
diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py
new file mode 100644
index 00000000..5592a898
--- /dev/null
+++ b/scripts/postprocessing_caption.py
@@ -0,0 +1,30 @@
+from modules import scripts_postprocessing, ui_components, deepbooru, shared
+import gradio as gr
+
+
+class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):
+ name = "Caption"
+ order = 4040
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Caption") as enable:
+ option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False)
+
+ return {
+ "enable": enable,
+ "option": option,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
+ if not enable:
+ return
+
+ captions = [pp.caption]
+
+ if "Deepbooru" in option:
+ captions.append(deepbooru.model.tag(pp.image))
+
+ if "BLIP" in option:
+ captions.append(shared.interrogator.interrogate(pp.image.convert("RGB")))
+
+ pp.caption = ", ".join([x for x in captions if x])
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
index a7d80d40..e1e156dd 100644
--- a/scripts/postprocessing_codeformer.py
+++ b/scripts/postprocessing_codeformer.py
@@ -1,28 +1,28 @@
from PIL import Image
import numpy as np
-from modules import scripts_postprocessing, codeformer_model
+from modules import scripts_postprocessing, codeformer_model, ui_components
import gradio as gr
-from modules.ui_components import FormRow
-
class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
name = "CodeFormer"
order = 3000
def ui(self):
- with FormRow():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+ with ui_components.InputAccordion(False, label="CodeFormer") as enable:
+ with gr.Row():
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
return {
+ "enable": enable,
"codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight,
}
- def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
- if codeformer_visibility == 0:
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight):
+ if codeformer_visibility == 0 or not enable:
return
restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py
new file mode 100644
index 00000000..b673003b
--- /dev/null
+++ b/scripts/postprocessing_create_flipped_copies.py
@@ -0,0 +1,32 @@
+from PIL import ImageOps, Image
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
+ name = "Create flipped copies"
+ order = 4030
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Create flipped copies") as enable:
+ with gr.Row():
+ option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False)
+
+ return {
+ "enable": enable,
+ "option": option,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
+ if not enable:
+ return
+
+ if "Horizontal" in option:
+ pp.extra_images.append(ImageOps.mirror(pp.image))
+
+ if "Vertical" in option:
+ pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))
+
+ if "Both" in option:
+ pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))
diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py
new file mode 100644
index 00000000..cff1dbc5
--- /dev/null
+++ b/scripts/postprocessing_focal_crop.py
@@ -0,0 +1,54 @@
+
+from modules import scripts_postprocessing, ui_components, errors
+import gradio as gr
+
+from modules.textual_inversion import autocrop
+
+
+class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
+ name = "Auto focal point crop"
+ order = 4010
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
+ face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight")
+ entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight")
+ edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight")
+ debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
+
+ return {
+ "enable": enable,
+ "face_weight": face_weight,
+ "entropy_weight": entropy_weight,
+ "edges_weight": edges_weight,
+ "debug": debug,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
+ if not enable:
+ return
+
+ if not pp.shared.target_width or not pp.shared.target_height:
+ return
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models()
+ except Exception:
+ errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)
+
+ autocrop_settings = autocrop.Settings(
+ crop_width=pp.shared.target_width,
+ crop_height=pp.shared.target_height,
+ face_points_weight=face_weight,
+ entropy_points_weight=entropy_weight,
+ corner_points_weight=edges_weight,
+ annotate_image=debug,
+ dnn_model_path=dnn_model_path,
+ )
+
+ result, *others = autocrop.crop_image(pp.image, autocrop_settings)
+
+ pp.image = result
+ pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others]
+
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
index d854f3f7..6e756605 100644
--- a/scripts/postprocessing_gfpgan.py
+++ b/scripts/postprocessing_gfpgan.py
@@ -1,26 +1,25 @@
from PIL import Image
import numpy as np
-from modules import scripts_postprocessing, gfpgan_model
+from modules import scripts_postprocessing, gfpgan_model, ui_components
import gradio as gr
-from modules.ui_components import FormRow
-
class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
name = "GFPGAN"
order = 2000
def ui(self):
- with FormRow():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+ with ui_components.InputAccordion(False, label="GFPGAN") as enable:
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility")
return {
+ "enable": enable,
"gfpgan_visibility": gfpgan_visibility,
}
- def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
- if gfpgan_visibility == 0:
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility):
+ if gfpgan_visibility == 0 or not enable:
return
restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
diff --git a/scripts/postprocessing_split_oversized.py b/scripts/postprocessing_split_oversized.py
new file mode 100644
index 00000000..c4a03160
--- /dev/null
+++ b/scripts/postprocessing_split_oversized.py
@@ -0,0 +1,71 @@
+import math
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
+
+class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):
+ name = "Split oversized images"
+ order = 4000
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Split oversized images") as enable:
+ with gr.Row():
+ split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold")
+ overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio")
+
+ return {
+ "enable": enable,
+ "split_threshold": split_threshold,
+ "overlap_ratio": overlap_ratio,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):
+ if not enable:
+ return
+
+ width = pp.shared.target_width
+ height = pp.shared.target_height
+
+ if not width or not height:
+ return
+
+ if pp.image.height > pp.image.width:
+ ratio = (pp.image.width * height) / (pp.image.height * width)
+ inverse_xy = False
+ else:
+ ratio = (pp.image.height * width) / (pp.image.width * height)
+ inverse_xy = True
+
+ if ratio >= 1.0 and ratio > split_threshold:
+ return
+
+ result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)
+
+ pp.image = result
+ pp.extra_images = [pp.create_copy(x) for x in others]
+
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index eb42a29e..ed709688 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -81,6 +81,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
return image
+ def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+ if upscale_mode == 1:
+ pp.shared.target_width = upscale_to_width
+ pp.shared.target_height = upscale_to_height
+ else:
+ pp.shared.target_width = int(pp.image.width * upscale_by)
+ pp.shared.target_height = int(pp.image.height * upscale_by)
+
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
if upscaler_1_name == "None":
upscaler_1_name = None
@@ -126,6 +134,10 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
"upscaler_name": upscaler_name,
}
+ def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+ pp.shared.target_width = int(pp.image.width * upscale_by)
+ pp.shared.target_height = int(pp.image.height * upscale_by)
+
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
if upscaler_name is None or upscaler_name == "None":
return
diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py
new file mode 100644
index 00000000..7e674989
--- /dev/null
+++ b/scripts/processing_autosized_crop.py
@@ -0,0 +1,64 @@
+from PIL import Image
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
+ wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
+ default=None
+ )
+ return wh and center_crop(image, *wh)
+
+
+class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
+ name = "Auto-sized crop"
+ order = 4020
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+ with gr.Row():
+ mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim")
+ maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim")
+ with gr.Row():
+ minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea")
+ maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea")
+ with gr.Row():
+ objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective")
+ threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold")
+
+ return {
+ "enable": enable,
+ "mindim": mindim,
+ "maxdim": maxdim,
+ "minarea": minarea,
+ "maxarea": maxarea,
+ "objective": objective,
+ "threshold": threshold,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):
+ if not enable:
+ return
+
+ cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)
+ if cropped is not None:
+ pp.image = cropped
+ else:
+ print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)")
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index ca73b2a5..a4a2f24d 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -114,6 +114,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
+ prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
@@ -124,9 +125,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
- return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
+ return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
- def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
+ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True
@@ -167,6 +168,18 @@ class Script(scripts.Script):
else:
setattr(copy_p, k, v)
+ if args.get("prompt") and p.prompt:
+ if prompt_position == "start":
+ copy_p.prompt = args.get("prompt") + " " + p.prompt
+ else:
+ copy_p.prompt = p.prompt + " " + args.get("prompt")
+
+ if args.get("negative_prompt") and p.negative_prompt:
+ if prompt_position == "start":
+ copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
+ else:
+ copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")
+
proc = process_images(copy_p)
images += proc.images
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 0dc255bc..2f385ebf 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -270,6 +270,7 @@ axis_options = [
AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
+ AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
]
@@ -437,13 +438,16 @@ class Script(scripts.Script):
with gr.Column():
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ with gr.Row():
+ vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.")
+ vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.")
+ vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.")
with gr.Column():
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+ csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
- with gr.Column():
- csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
@@ -475,6 +479,8 @@ class Script(scripts.Script):
fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
+ axis_type = axis_type or 0 # if axle type is None set to 0
+
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
@@ -522,9 +528,11 @@ class Script(scripts.Script):
(z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
)
- return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]
+ return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode]
+
+ def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode):
+ x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0
- def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -688,7 +696,7 @@ class Script(scripts.Script):
grid_infotext = [None] * (1 + len(zs))
def cell(x, y, z, ix, iy, iz):
- if shared.state.interrupted:
+ if shared.state.interrupted or state.stopping_generation:
return Processed(p, [], p.seed, "")
pc = copy(p)
@@ -697,6 +705,16 @@ class Script(scripts.Script):
y_opt.apply(pc, y, ys)
z_opt.apply(pc, z, zs)
+ xdim = len(xs) if vary_seeds_x else 1
+ ydim = len(ys) if vary_seeds_y else 1
+
+ if vary_seeds_x:
+ pc.seed += ix
+ if vary_seeds_y:
+ pc.seed += iy * xdim
+ if vary_seeds_z:
+ pc.seed += iz * xdim * ydim
+
try:
res = process_images(pc)
except Exception as e: