aboutsummaryrefslogtreecommitdiff
path: root/scripts/xyz_grid.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/xyz_grid.py')
-rw-r--r--scripts/xyz_grid.py27
1 files changed, 14 insertions, 13 deletions
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 1010845e..d37b428f 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -3,6 +3,7 @@ from copy import copy
from itertools import permutations, chain
import random
import csv
+import os.path
from io import StringIO
from PIL import Image
import numpy as np
@@ -10,7 +11,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
@@ -66,14 +67,6 @@ def apply_order(p, x, xs):
p.prompt = prompt_tmp + p.prompt
-def apply_sampler(p, x, xs):
- sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
- if sampler_name is None:
- raise RuntimeError(f"Unknown sampler: {x}")
-
- p.sampler_name = sampler_name
-
-
def confirm_samplers(p, xs):
for x in xs:
if x.lower() not in sd_samplers.samplers_map:
@@ -182,6 +175,8 @@ def do_nothing(p, x, xs):
def format_nothing(p, opt, x):
return ""
+def format_remove_path(p, opt, x):
+ return os.path.basename(x)
def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
@@ -221,9 +216,10 @@ axis_options = [
AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
- AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
- AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
+ AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+ AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")),
@@ -648,7 +644,12 @@ class Script(scripts.Script):
y_opt.apply(pc, y, ys)
z_opt.apply(pc, z, zs)
- res = process_images(pc)
+ try:
+ res = process_images(pc)
+ except Exception as e:
+ errors.display(e, "generating image for xyz plot")
+
+ res = Processed(p, [], p.seed, "")
# Sets subgrid infotexts
subgrid_index = 1 + iz