aboutsummaryrefslogtreecommitdiff
path: root/modules/scripts.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/scripts.py')
-rw-r--r--modules/scripts.py170
1 files changed, 122 insertions, 48 deletions
diff --git a/modules/scripts.py b/modules/scripts.py
index 533db45c..b934d881 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,11 +3,10 @@ import sys
import traceback
from collections import namedtuple
-import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions
+from modules import shared, paths, script_callbacks, extensions, script_loading
AlwaysVisible = object()
@@ -18,6 +17,12 @@ class Script:
args_to = None
alwayson = False
+ is_txt2img = False
+ is_img2img = False
+
+ """A gr.Group component that has all script's UI inside it"""
+ group = None
+
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
@@ -70,6 +75,19 @@ class Script:
pass
+ def process_batch(self, p, *args, **kwargs):
+ """
+ Same as process(), but called for every batch.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -78,6 +96,23 @@ class Script:
pass
+ def before_component(self, component, **kwargs):
+ """
+ Called before a component is created.
+ Use elem_id/label fields of kwargs to figure out which component it is.
+ This can be useful to inject your own components somewhere in the middle of vanilla UI.
+ You can return created components in the ui() function to add them to the list of arguments for your processing functions
+ """
+
+ pass
+
+ def after_component(self, component, **kwargs):
+ """
+ Called after a component is created. Same as above.
+ """
+
+ pass
+
def describe(self):
"""unused"""
return ""
@@ -125,7 +160,7 @@ def list_files_with_name(filename):
continue
path = os.path.join(dirpath, filename)
- if os.path.isfile(filename):
+ if os.path.isfile(path):
res.append(path)
return res
@@ -146,13 +181,7 @@ def load_scripts():
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
- with open(scriptfile.path, "r", encoding="utf8") as file:
- text = file.read()
-
- from types import ModuleType
- compiled = compile(text, scriptfile.path, 'exec')
- module = ModuleType(scriptfile.filename)
- exec(compiled, module.__dict__)
+ module = script_loading.load_module(scriptfile.path)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@@ -186,12 +215,18 @@ class ScriptRunner:
self.titles = []
self.infotext_fields = []
- def setup_ui(self, is_img2img):
+ def initialize_scripts(self, is_img2img):
+ self.scripts.clear()
+ self.alwayson_scripts.clear()
+ self.selectable_scripts.clear()
+
for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
+ script.is_txt2img = not is_img2img
+ script.is_img2img = is_img2img
- visibility = script.show(is_img2img)
+ visibility = script.show(script.is_img2img)
if visibility == AlwaysVisible:
self.scripts.append(script)
@@ -202,6 +237,7 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
+ def setup_ui(self):
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
inputs = [None]
@@ -211,15 +247,13 @@ class ScriptRunner:
script.args_from = len(inputs)
script.args_to = len(inputs)
- controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None:
return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- if not script.alwayson:
- control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
@@ -229,40 +263,41 @@ class ScriptRunner:
script.args_to = len(inputs)
for script in self.alwayson_scripts:
- with gr.Group():
+ with gr.Group() as group:
create_script_ui(script, inputs, inputs_alwayson)
+ script.group = group
+
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
- create_script_ui(script, inputs, inputs_alwayson)
+ with gr.Group(visible=False) as group:
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ script.group = group
def select_script(script_index):
- if 0 < script_index <= len(self.selectable_scripts):
- script = self.selectable_scripts[script_index-1]
- args_from = script.args_from
- args_to = script.args_to
- else:
- args_from = 0
- args_to = 0
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title):
+ """called when an initial value is set from ui-config.json to show script's UI components"""
+
if title == 'None':
return
+
script_index = self.titles.index(title)
- script = self.selectable_scripts[script_index]
- for i in range(script.args_from, script.args_to):
- inputs[i].visible = True
+ self.selectable_scripts[script_index].group.visible = True
dropdown.init_field = init_field
+
dropdown.change(
fn=select_script,
inputs=[dropdown],
- outputs=inputs
+ outputs=[script.group for script in self.selectable_scripts]
)
return inputs
@@ -294,6 +329,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def process_batch(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process_batch(p, *script_args, **kwargs)
+ except Exception:
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:
@@ -303,33 +347,44 @@ class ScriptRunner:
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def before_component(self, component, **kwargs):
+ for script in self.scripts:
+ try:
+ script.before_component(component, **kwargs)
+ except Exception:
+ print(f"Error running before_component: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def after_component(self, component, **kwargs):
+ for script in self.scripts:
+ try:
+ script.after_component(component, **kwargs)
+ except Exception:
+ print(f"Error running after_component: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
- with open(script.filename, "r", encoding="utf8") as file:
- args_from = script.args_from
- args_to = script.args_to
- filename = script.filename
- text = file.read()
-
- from types import ModuleType
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
- module = cache.get(filename, None)
- if module is None:
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
- cache[filename] = module
+ module = cache.get(filename, None)
+ if module is None:
+ module = script_loading.load_module(script.filename)
+ cache[filename] = module
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- self.scripts[si] = script_class()
- self.scripts[si].filename = filename
- self.scripts[si].args_from = args_from
- self.scripts[si].args_to = args_to
+ for key, script_class in module.__dict__.items():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+scripts_current: ScriptRunner = None
def reload_script_body_only():
@@ -346,3 +401,22 @@ def reload_scripts():
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
+def IOComponent_init(self, *args, **kwargs):
+ if scripts_current is not None:
+ scripts_current.before_component(self, **kwargs)
+
+ script_callbacks.before_component_callback(self, **kwargs)
+
+ res = original_IOComponent_init(self, *args, **kwargs)
+
+ script_callbacks.after_component_callback(self, **kwargs)
+
+ if scripts_current is not None:
+ scripts_current.after_component(self, **kwargs)
+
+ return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+gr.components.IOComponent.__init__ = IOComponent_init