aboutsummaryrefslogtreecommitdiff
path: root/scripts/xy_grid.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r--scripts/xy_grid.py40
1 files changed, 38 insertions, 2 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 146663b0..044c30e6 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -1,5 +1,6 @@
from collections import namedtuple
from copy import copy
+from itertools import permutations
import random
from PIL import Image
@@ -28,6 +29,27 @@ def apply_prompt(p, x, xs):
p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+def apply_order(p, x, xs):
+ token_order = []
+
+ # Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ search_from_pos = 0
+ for idx, token in enumerate(x):
+ original_pos, old_token = token_order[idx]
+
+ # Get position of the token again as it will likely change as tokens are being replaced
+ pos = p.prompt.find(old_token)
+ if original_pos >= 0:
+ # Avoid trying to replace what was just replaced by searching later in the prompt string
+ p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1)
+
+ search_from_pos = pos + len(token)
+
samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
@@ -60,7 +82,8 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
-
+ if type(x) == type(list()):
+ x = str(x)
return x
def do_nothing(p, x, xs):
@@ -89,6 +112,7 @@ axis_options = [
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
+ AxisOption("Prompt order", type(list()), apply_order, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
@@ -159,7 +183,11 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]
- valslist = [x.strip() for x in vals.split(",")]
+ if opt.type == type(list()):
+ valslist = [x for x in vals]
+ else:
+ valslist = [x.strip() for x in vals.split(",")]
+
if opt.type == int:
valslist_ext = []
@@ -212,9 +240,17 @@ class Script(scripts.Script):
return valslist
x_opt = axis_options[x_type]
+
+ if x_opt.label == "Prompt order":
+ x_values = list(permutations([x.strip() for x in x_values.split(",")]))
+
xs = process_axis(x_opt, x_values)
y_opt = axis_options[y_type]
+
+ if y_opt.label == "Prompt order":
+ y_values = list(permutations([y.strip() for y in y_values.split(",")]))
+
ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list):