aboutsummaryrefslogtreecommitdiff
path: root/scripts/prompt_matrix.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-03 19:32:45 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-03 19:32:45 +0300
commit592334f322d403679a125225afb5ff0114935edd (patch)
tree3f8bc6963488eba5d4a1421e98de6a0f49adc439 /scripts/prompt_matrix.py
parent595c827bd31773cc98eb6e87b11090960a32b2a2 (diff)
scripts
Diffstat (limited to 'scripts/prompt_matrix.py')
-rw-r--r--scripts/prompt_matrix.py82
1 files changed, 82 insertions, 0 deletions
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
new file mode 100644
index 00000000..7087bcde
--- /dev/null
+++ b/scripts/prompt_matrix.py
@@ -0,0 +1,82 @@
+import math
+from collections import namedtuple
+from copy import copy
+import random
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import images
+from modules.processing import process_images, Processed
+from modules.shared import opts, cmd_opts, state
+import modules.sd_samplers
+
+
+def draw_xy_grid(xs, ys, x_label, y_label, cell):
+ res = []
+
+ ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
+ hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
+
+ first_pocessed = None
+
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+
+ processed = cell(x, y)
+ if first_pocessed is None:
+ first_pocessed = processed
+
+ res.append(processed.images[0])
+
+ grid = images.image_grid(res, rows=len(ys))
+ grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
+
+ first_pocessed.images = [grid]
+
+ return first_pocessed
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Prompt matrix"
+
+ def ui(self, is_img2img):
+ put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False)
+
+ return [put_at_start]
+
+ def run(self, p, put_at_start):
+ seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
+
+ original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
+
+ all_prompts = []
+ prompt_matrix_parts = original_prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
+ for combination_num in range(combination_count):
+ selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
+
+ if put_at_start:
+ selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
+ else:
+ selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
+
+ all_prompts.append(", ".join(selected_prompts))
+
+ p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
+ p.do_not_save_grid = True
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
+
+ p.prompt = all_prompts
+ p.prompt_for_display = original_prompt
+ p.seed = len(all_prompts) * [seed]
+ processed = process_images(p)
+
+ grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
+ grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+ processed.images.insert(0, grid)
+
+ return processed