aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorljleb <set>2023-07-24 13:52:24 -0400
committerljleb <set>2023-07-24 13:52:24 -0400
commitca45ff1ae6fdd5c2dcd754fde95dd29f49bd414b (patch)
tree2a16bb1b9e7985b98c4c8f991b602c856e80a87c /modules/processing.py
parentf451994053140622ef5e394bc02ac166fb74e56f (diff)
add postprocess_batch_list callback
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/modules/processing.py b/modules/processing.py
index a74a5302..c16404f4 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -717,7 +717,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
+ all_prompts = p.all_prompts[:]
+ all_seeds = p.all_seeds[:]
+ all_subseeds = p.all_subseeds[:]
+
+ # apply changes to generation data
+ all_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.prompts
+ all_seeds[n * p.batch_size:(n + 1) * p.batch_size] = p.seeds
+ all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] = p.subseeds
+
+ # update p.all_negative_prompts in case extensions changed the size of the batch
+ # create_infotext below uses it
+ old_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.negative_prompts
+
+ try:
+ return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
+ finally:
+ # restore p.all_negative_prompts in case extensions changed the size of the batch
+ p.all_negative_prompts[n * p.batch_size:n * p.batch_size + len(p.negative_prompts)] = old_negative_prompts
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -806,6 +824,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+ postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+ p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
+ x_samples_ddim = postprocess_batch_list_args.images
+
for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i