aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-03-11 14:33:38 +0300
committerGitHub <noreply@github.com>2023-03-11 14:33:38 +0300
commitda3f942ab2171e11adf47cd21182db644b9c400a (patch)
tree195ee0c27ce6fdb54146b91eb070f26294443cd7
parentaaa367e35ce4e823219c2954ca141ca1ed14800e (diff)
parenta2d635ad135241a0a40f67f7e1638c9c8a4ded04 (diff)
Merge pull request #8017 from space-nuko/before-process-batch
Add `before_process_batch` script callback
-rw-r--r--modules/processing.py3
-rw-r--r--modules/scripts.py23
2 files changed, 26 insertions, 0 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 0b2f7e60..06e7a440 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ if p.scripts is not None:
+ p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
if len(prompts) == 0:
break
diff --git a/modules/scripts.py b/modules/scripts.py
index 24056a12..e6a505b3 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -80,6 +80,20 @@ class Script:
pass
+ def before_process_batch(self, p, *args, **kwargs):
+ """
+ Called before extra networks are parsed from the prompt, so you can add
+ new extra network keywords to the prompt with this callback.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -388,6 +402,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def before_process_batch(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process_batch(p, *script_args, **kwargs)
+ except Exception:
+ print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try: