aboutsummaryrefslogtreecommitdiff
path: root/modules/scripts.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/scripts.py')
-rw-r--r--modules/scripts.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/modules/scripts.py b/modules/scripts.py
index b934d881..23ca195d 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -88,6 +88,17 @@ class Script:
pass
+ def postprocess_batch(self, p, *args, **kwargs):
+ """
+ Same as process_batch(), but called for every batch after it has been generated.
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -347,6 +358,15 @@ class ScriptRunner:
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_batch(self, p, images, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try: