aboutsummaryrefslogtreecommitdiff
path: root/modules/scripts.py
diff options
context:
space:
mode:
authorJaredTherriault <noirjt@live.com>2023-09-04 17:29:33 -0700
committerGitHub <noreply@github.com>2023-09-04 17:29:33 -0700
commit5e16914a4e157ab3ed96f8b7841e1290a56f4484 (patch)
tree655f4582e692f0fc3667b3b668ad365ac3ab92ae /modules/scripts.py
parent8f3b02f09535f55d3673aa9ea589396b8614f799 (diff)
parent5ef669de080814067961f28357256e8fe27544f4 (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/scripts.py')
-rw-r--r--modules/scripts.py193
1 files changed, 136 insertions, 57 deletions
diff --git a/modules/scripts.py b/modules/scripts.py
index 5b4edcac..e8518ad0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,6 +3,7 @@ import re
import sys
import inspect
from collections import namedtuple
+from dataclasses import dataclass
import gradio as gr
@@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
self.images = images
+@dataclass
+class OnComponent:
+ component: gr.blocks.Block
+
+
class Script:
name = None
"""script's internal name derived from title"""
@@ -35,9 +41,13 @@ class Script:
is_txt2img = False
is_img2img = False
+ tabname = None
group = None
- """A gr.Group component that has all script's UI inside it"""
+ """A gr.Group component that has all script's UI inside it."""
+
+ create_group = True
+ """If False, for alwayson scripts, a group component will not be created."""
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -52,6 +62,15 @@ class Script:
api_info = None
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+ on_before_component_elem_id = None
+ """list of callbacks to be called before a component with an elem_id is created"""
+
+ on_after_component_elem_id = None
+ """list of callbacks to be called after a component with an elem_id is created"""
+
+ setup_for_ui_only = False
+ """If true, the script setup will only be run in Gradio UI, not in API"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -90,9 +109,16 @@ class Script:
pass
+ def setup(self, p, *args):
+ """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
+ args contains all values returned by components from ui().
+ """
+ pass
+
+
def before_process(self, p, *args):
"""
- This function is called very early before processing begins for AlwaysVisible scripts.
+ This function is called very early during processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui()
"""
@@ -212,6 +238,29 @@ class Script:
pass
+ def on_before_component(self, callback, *, elem_id):
+ """
+ Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
+
+ May be called in show() or ui() - but it may be too late in latter as some components may already be created.
+
+ This function is an alternative to before_component in that it also cllows to run before a component is created, but
+ it doesn't require to be called for every created component - just for the one you need.
+ """
+ if self.on_before_component_elem_id is None:
+ self.on_before_component_elem_id = []
+
+ self.on_before_component_elem_id.append((elem_id, callback))
+
+ def on_after_component(self, callback, *, elem_id):
+ """
+ Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
+ """
+ if self.on_after_component_elem_id is None:
+ self.on_after_component_elem_id = []
+
+ self.on_after_component_elem_id.append((elem_id, callback))
+
def describe(self):
"""unused"""
return ""
@@ -220,7 +269,7 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
- tabkind = 'img2img' if self.is_img2img else 'txt2txt'
+ tabkind = 'img2img' if self.is_img2img else 'txt2img'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
@@ -232,6 +281,19 @@ class Script:
"""
pass
+
+class ScriptBuiltinUI(Script):
+ setup_for_ui_only = True
+
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
+
+ need_tabname = self.show(True) == self.show(False)
+ tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
+
+ return f'{tabname}{item_id}'
+
+
current_basedir = paths.script_path
@@ -250,7 +312,7 @@ postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
-def list_scripts(scriptdirname, extension):
+def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = []
basedir = os.path.join(paths.script_path, scriptdirname)
@@ -258,8 +320,9 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- for ext in extensions.active():
- scripts_list += ext.list_files(scriptdirname, extension)
+ if include_extensions:
+ for ext in extensions.active():
+ scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -288,7 +351,7 @@ def load_scripts():
postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
- scripts_list = list_scripts("scripts", ".py")
+ scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
syspath = sys.path
@@ -349,10 +412,17 @@ class ScriptRunner:
self.selectable_scripts = []
self.alwayson_scripts = []
self.titles = []
+ self.title_map = {}
self.infotext_fields = []
self.paste_field_names = []
self.inputs = [None]
+ self.on_before_component_elem_id = {}
+ """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
+
+ self.on_after_component_elem_id = {}
+ """dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
+
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -367,6 +437,7 @@ class ScriptRunner:
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
+ script.tabname = "img2img" if is_img2img else "txt2img"
visibility = script.show(script.is_img2img)
@@ -379,6 +450,28 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
+ self.apply_on_before_component_callbacks()
+
+ def apply_on_before_component_callbacks(self):
+ for script in self.scripts:
+ on_before = script.on_before_component_elem_id or []
+ on_after = script.on_after_component_elem_id or []
+
+ for elem_id, callback in on_before:
+ if elem_id not in self.on_before_component_elem_id:
+ self.on_before_component_elem_id[elem_id] = []
+
+ self.on_before_component_elem_id[elem_id].append((callback, script))
+
+ for elem_id, callback in on_after:
+ if elem_id not in self.on_after_component_elem_id:
+ self.on_after_component_elem_id[elem_id] = []
+
+ self.on_after_component_elem_id[elem_id].append((callback, script))
+
+ on_before.clear()
+ on_after.clear()
+
def create_script_ui(self, script):
import modules.api.models as api_models
@@ -429,15 +522,20 @@ class ScriptRunner:
if script.alwayson and script.section != section:
continue
- with gr.Group(visible=script.alwayson) as group:
- self.create_script_ui(script)
+ if script.create_group:
+ with gr.Group(visible=script.alwayson) as group:
+ self.create_script_ui(script)
- script.group = group
+ script.group = group
+ else:
+ self.create_script_ui(script)
def prepare_ui(self):
self.inputs = [None]
def setup_ui(self):
+ all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
+ self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
self.setup_ui_for_section(None)
@@ -484,6 +582,8 @@ class ScriptRunner:
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
+ self.apply_on_before_component_callbacks()
+
return self.inputs
def run(self, p, *args):
@@ -577,6 +677,12 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
+ for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
+ try:
+ callback(OnComponent(component=component))
+ except Exception:
+ errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
+
for script in self.scripts:
try:
script.before_component(component, **kwargs)
@@ -584,12 +690,21 @@ class ScriptRunner:
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
+ for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
+ try:
+ callback(OnComponent(component=component))
+ except Exception:
+ errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
+
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
+ def script(self, title):
+ return self.title_map.get(title.lower())
+
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
args_from = script.args_from
@@ -608,7 +723,6 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
-
def before_hr(self, p):
for script in self.alwayson_scripts:
try:
@@ -617,6 +731,17 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+ def setup_scrips(self, p, *, is_ui=True):
+ for script in self.alwayson_scripts:
+ if not is_ui and script.setup_for_ui_only:
+ continue
+
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.setup(p, *script_args)
+ except Exception:
+ errors.report(f"Error running setup: {script.filename}", exc_info=True)
+
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
@@ -631,49 +756,3 @@ def reload_script_body_only():
reload_scripts = load_scripts # compatibility alias
-
-
-def add_classes_to_gradio_component(comp):
- """
- this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
- """
-
- comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
-
- if getattr(comp, 'multiselect', False):
- comp.elem_classes.append('multiselect')
-
-
-
-def IOComponent_init(self, *args, **kwargs):
- if scripts_current is not None:
- scripts_current.before_component(self, **kwargs)
-
- script_callbacks.before_component_callback(self, **kwargs)
-
- res = original_IOComponent_init(self, *args, **kwargs)
-
- add_classes_to_gradio_component(self)
-
- script_callbacks.after_component_callback(self, **kwargs)
-
- if scripts_current is not None:
- scripts_current.after_component(self, **kwargs)
-
- return res
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-gr.components.IOComponent.__init__ = IOComponent_init
-
-
-def BlockContext_init(self, *args, **kwargs):
- res = original_BlockContext_init(self, *args, **kwargs)
-
- add_classes_to_gradio_component(self)
-
- return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init