aboutsummaryrefslogtreecommitdiff
path: root/scripts/xy_grid.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r--scripts/xy_grid.py35
1 files changed, 18 insertions, 17 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 7def47f5..1237e754 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -29,10 +29,11 @@ def apply_prompt(p, x, xs):
p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+
def apply_order(p, x, xs):
token_order = []
- # Initally grab the tokens from the prompt so they can be be replaced in order of earliest seen
+ # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
for token in x:
token_order.append((p.prompt.find(token), token))
@@ -85,17 +86,26 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
- if type(x) == type(list()):
- x = str(x)
return x
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
def do_nothing(p, x, xs):
pass
+
def format_nothing(p, opt, x):
return ""
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
@@ -108,6 +118,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
@@ -115,7 +126,6 @@ axis_options = [
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
- AxisOption("Prompt order", type(list()), apply_order, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
@@ -158,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+
class Script(scripts.Script):
def title(self):
return "X/Y plot"
@@ -186,11 +197,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]
- if opt.type == type(list()):
- valslist = [x for x in vals]
- else:
- valslist = [x.strip() for x in vals.split(",")]
-
+ valslist = [x.strip() for x in vals.split(",")]
if opt.type == int:
valslist_ext = []
@@ -237,23 +244,17 @@ class Script(scripts.Script):
valslist_ext.append(val)
valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]
return valslist
x_opt = axis_options[x_type]
-
- if x_opt.label == "Prompt order":
- x_values = list(permutations([x.strip() for x in x_values.split(",")]))
-
xs = process_axis(x_opt, x_values)
y_opt = axis_options[y_type]
-
- if y_opt.label == "Prompt order":
- y_values = list(permutations([y.strip() for y in y_values.split(",")]))
-
ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list):