aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/processing.py2
-rw-r--r--modules/scripts.py16
2 files changed, 12 insertions, 6 deletions
diff --git a/modules/processing.py b/modules/processing.py
index e20d8fc4..03c9143d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -502,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
break
if p.scripts is not None:
- p.scripts.process_one(p, n)
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
diff --git a/modules/scripts.py b/modules/scripts.py
index 75e47cd2..366c90d7 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -73,9 +73,15 @@ class Script:
pass
- def process_one(self, p, n, *args):
+ def process_batch(self, p, *args, **kwargs):
"""
- Same as process(), but called for every iteration
+ Same as process(), but called for every batch.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
"""
pass
@@ -303,13 +309,13 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- def process_one(self, p, n):
+ def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
- script.process_one(p, n, *script_args)
+ script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_one: {script.filename}", file=sys.stderr)
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed):