aboutsummaryrefslogtreecommitdiff
path: root/modules/scripts.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-08 16:50:23 +0300
committerGitHub <noreply@github.com>2023-07-08 16:50:23 +0300
commit993dd9a8927407de8d19142cacb07e6f76686a67 (patch)
tree89df31c33ecf054c9cb70aaf38fc3382f306f450 /modules/scripts.py
parentff6acd35d0807a4e0c3ee86cdb1520a4a3a11cdd (diff)
parentd7d6e8cfc8b85a99a48f82975ee213d487783c28 (diff)
Merge branch 'dev' into patch-1
Diffstat (limited to 'modules/scripts.py')
-rw-r--r--modules/scripts.py225
1 files changed, 152 insertions, 73 deletions
diff --git a/modules/scripts.py b/modules/scripts.py
index d945b89f..7d9dd59f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,12 +1,12 @@
import os
import re
import sys
-import traceback
+import inspect
from collections import namedtuple
import gradio as gr
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
AlwaysVisible = object()
@@ -17,6 +17,12 @@ class PostprocessImageArgs:
class Script:
+ name = None
+ """script's internal name derived from title"""
+
+ section = None
+ """name of UI section that the script's controls will be placed into"""
+
filename = None
args_from = None
args_to = None
@@ -25,8 +31,8 @@ class Script:
is_txt2img = False
is_img2img = False
- """A gr.Group component that has all script's UI inside it"""
group = None
+ """A gr.Group component that has all script's UI inside it"""
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -38,6 +44,9 @@ class Script:
various "Send to <X>" buttons when clicked
"""
+ api_info = None
+ """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -76,6 +85,15 @@ class Script:
pass
+ def before_process(self, p, *args):
+ """
+ This function is called very early before processing begins for AlwaysVisible scripts.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
@@ -99,6 +117,21 @@ class Script:
pass
+ def after_extra_networks_activate(self, p, *args, **kwargs):
+ """
+ Calledafter extra networks activation, before conds calculation
+ allow modification of the network after extra networks activation been applied
+ won't be call if p.disable_extra_networks
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ - extra_network_data - list of ExtraNetworkParams for current stage
+ """
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -169,6 +202,11 @@ class Script:
return f'script_{tabname}{title}_{item_id}'
+ def before_hr(self, p, *args):
+ """
+ This function is called before hires fix start.
+ """
+ pass
current_basedir = paths.script_path
@@ -231,8 +269,8 @@ def load_scripts():
syspath = sys.path
def register_scripts_from_module(module):
- for key, script_class in module.__dict__.items():
- if type(script_class) != type:
+ for script_class in module.__dict__.values():
+ if not inspect.isclass(script_class):
continue
if issubclass(script_class, Script):
@@ -258,21 +296,25 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
current_basedir = paths.script_path
+ timer.startup_timer.record(scriptfile.filename)
+
+ global scripts_txt2img, scripts_img2img, scripts_postproc
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
- res = func(*args, **kwargs)
- return res
+ return func(*args, **kwargs)
except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -285,6 +327,7 @@ class ScriptRunner:
self.titles = []
self.infotext_fields = []
self.paste_field_names = []
+ self.inputs = [None]
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -295,9 +338,9 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
- for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
- script = script_class()
- script.filename = path
+ for script_data in auto_processing_scripts + scripts_data:
+ script = script_data.script_class()
+ script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -312,49 +355,74 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
- def setup_ui(self):
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+ def create_script_ui(self, script):
+ import modules.api.models as api_models
- inputs = [None]
- inputs_alwayson = [True]
+ script.args_from = len(self.inputs)
+ script.args_to = len(self.inputs)
- def create_script_ui(script, inputs, inputs_alwayson):
- script.args_from = len(inputs)
- script.args_to = len(inputs)
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
- controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+ if controls is None:
+ return
- if controls is None:
- return
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+ api_args = []
- for control in controls:
- control.custom_script_source = os.path.basename(script.filename)
+ for control in controls:
+ control.custom_script_source = os.path.basename(script.filename)
- if script.infotext_fields is not None:
- self.infotext_fields += script.infotext_fields
+ arg_info = api_models.ScriptArg(label=control.label or "")
- if script.paste_field_names is not None:
- self.paste_field_names += script.paste_field_names
+ for field in ("value", "minimum", "maximum", "step", "choices"):
+ v = getattr(control, field, None)
+ if v is not None:
+ setattr(arg_info, field, v)
- inputs += controls
- inputs_alwayson += [script.alwayson for _ in controls]
- script.args_to = len(inputs)
+ api_args.append(arg_info)
- for script in self.alwayson_scripts:
- with gr.Group() as group:
- create_script_ui(script, inputs, inputs_alwayson)
+ script.api_info = api_models.ScriptInfo(
+ name=script.name,
+ is_img2img=script.is_img2img,
+ is_alwayson=script.alwayson,
+ args=api_args,
+ )
- script.group = group
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
- dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- inputs[0] = dropdown
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
+
+ self.inputs += controls
+ script.args_to = len(self.inputs)
+
+ def setup_ui_for_section(self, section, scriptlist=None):
+ if scriptlist is None:
+ scriptlist = self.alwayson_scripts
+
+ for script in scriptlist:
+ if script.alwayson and script.section != section:
+ continue
- for script in self.selectable_scripts:
- with gr.Group(visible=False) as group:
- create_script_ui(script, inputs, inputs_alwayson)
+ with gr.Group(visible=script.alwayson) as group:
+ self.create_script_ui(script)
script.group = group
+ def prepare_ui(self):
+ self.inputs = [None]
+
+ def setup_ui(self):
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+
+ self.setup_ui_for_section(None)
+
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
+ self.inputs[0] = dropdown
+
+ self.setup_ui_for_section(None, self.selectable_scripts)
+
def select_script(script_index):
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
@@ -378,6 +446,7 @@ class ScriptRunner:
)
self.script_load_ctr = 0
+
def onload_script_visibility(params):
title = params.get('Script', None)
if title:
@@ -388,10 +457,10 @@ class ScriptRunner:
else:
return gr.update(visible=False)
- self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
- self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
+ self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
+ self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
- return inputs
+ return self.inputs
def run(self, p, *args):
script_index = args[0]
@@ -411,14 +480,21 @@ class ScriptRunner:
return processed
+ def before_process(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_process: {script.filename}", exc_info=True)
+
def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -426,8 +502,15 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
+
+ def after_extra_networks_activate(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.after_extra_networks_activate(p, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -435,8 +518,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -444,8 +526,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -453,8 +534,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -462,24 +542,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
@@ -492,7 +569,7 @@ class ScriptRunner:
module = script_loading.load_module(script.filename)
cache[filename] = module
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
@@ -500,9 +577,18 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
-scripts_txt2img = ScriptRunner()
-scripts_img2img = ScriptRunner()
-scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+ def before_hr(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_hr(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
+scripts_txt2img: ScriptRunner = None
+scripts_img2img: ScriptRunner = None
+scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None
@@ -512,14 +598,7 @@ def reload_script_body_only():
scripts_img2img.reload_sources(cache)
-def reload_scripts():
- global scripts_txt2img, scripts_img2img, scripts_postproc
-
- load_scripts()
-
- scripts_txt2img = ScriptRunner()
- scripts_img2img = ScriptRunner()
- scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+reload_scripts = load_scripts # compatibility alias
def add_classes_to_gradio_component(comp):