import dataclasses
import os

import gradio as gr

from modules import errors, shared


@dataclasses.dataclass
class PostprocessedImageSharedInfo:
    """Values shared by a PostprocessedImage and every copy made from it.

    create_copy() hands the same instance to the copy, so a change made
    through one image is visible through all related images.
    """

    target_width: int = None
    target_height: int = None


class PostprocessedImage:
    """An image travelling through the postprocessing pipeline, plus its metadata."""

    def __init__(self, image):
        self.image = image
        self.info = {}
        self.shared = PostprocessedImageSharedInfo()
        # Images queued by a script; the runner moves them into its work list
        # and clears this list after each script has processed this image.
        self.extra_images = []
        self.nametags = []
        # When True, the runner does not call process() for this image.
        self.disable_processing = False
        self.caption = None

    def get_suffix(self, used_suffixes=None):
        """Return a filename suffix built from the nametags, unique within used_suffixes.

        The base suffix is "-" followed by the nametags joined with "-", or an
        empty string when there are no nametags. If the base suffix is already
        a key of used_suffixes, "-1" through "-99" are tried in order. The
        chosen suffix is recorded in used_suffixes (mutated in place). When
        every candidate is taken, the base suffix is returned without being
        recorded again.

        If used_suffixes is None, a fresh dict is used, so the base suffix is
        always returned.
        """
        used_suffixes = {} if used_suffixes is None else used_suffixes

        suffix = "-".join(self.nametags)
        if suffix:
            suffix = "-" + suffix

        if suffix not in used_suffixes:
            used_suffixes[suffix] = 1
            return suffix

        for i in range(1, 100):
            proposed_suffix = suffix + "-" + str(i)

            if proposed_suffix not in used_suffixes:
                used_suffixes[proposed_suffix] = 1
                return proposed_suffix

        return suffix

    def create_copy(self, new_image, *, nametags=None, disable_processing=False):
        """Create a PostprocessedImage for new_image that inherits this one's metadata.

        The copy shares the same `shared` object, and receives shallow copies
        of `nametags` and `info`. Entries in `nametags` (if given) are appended
        to the copied list. `caption` and `extra_images` are not carried over.
        """
        pp = PostprocessedImage(new_image)
        pp.shared = self.shared
        pp.nametags = self.nametags.copy()
        pp.info = self.info.copy()
        pp.disable_processing = disable_processing

        if nametags is not None:
            pp.nametags += nametags

        return pp


class ScriptPostprocessing:
    """Base class for postprocessing scripts; subclasses override the hooks below."""

    # Set by ScriptPostprocessingRunner.initialize_scripts().
    filename = None

    # Set by ScriptPostprocessingRunner.create_script_ui(): the mapping returned
    # by ui(), and the [args_from, args_to) slice of the flat input list that
    # belongs to this script.
    controls = None
    args_from = None
    args_to = None

    order = 1000
    """scripts will be ordered by this value in postprocessing UI"""

    name = None
    """the title of the script; also used to look the script up in settings and arguments"""

    group = None
    """the gradio container created in setup_ui() that has all of the script's UI inside it"""

    def ui(self):
        """
        This function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be a dictionary that maps parameter names to components used in processing.
        Values of those components will be passed to process() function.
        """

        pass

    def process(self, pp: PostprocessedImage, **args):
        """
        This function is called to postprocess the image.
        args contains a dictionary with all values returned by components from ui()
        """

        pass

    def process_firstpass(self, pp: PostprocessedImage, **args):
        """
        Called for all scripts before calling process(). Scripts can examine the image here and set fields
        of the pp object to communicate things to other scripts.
        args contains a dictionary with all values returned by components from ui()
        """

        pass

    def image_changed(self):
        """Hook invoked by ScriptPostprocessingRunner.image_changed(); does nothing by default."""
        pass


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call func(*args, **kwargs); on any Exception, report it and return `default`.

    The broad `except Exception` is deliberate: this is the boundary that keeps
    a failing script from propagating its error to the caller. The error is
    passed to errors.display() together with the script file and hook name.
    """
    try:
        return func(*args, **kwargs)
    except Exception as e:
        errors.display(e, f"calling {filename}/{funcname}")

    return default


class ScriptPostprocessingRunner:
    """Instantiates postprocessing scripts, builds their UI and runs them on images."""

    def __init__(self):
        self.scripts = None
        self.ui_created = False

    def initialize_scripts(self, scripts_data):
        """Instantiate one script per entry of scripts_data and store them in self.scripts.

        Each entry must provide `script_class` (called with no arguments) and
        `path` (stored as the script's filename). Scripts whose name is
        "Simple Upscale" are skipped.
        """
        self.scripts = []

        for script_data in scripts_data:
            script: ScriptPostprocessing = script_data.script_class()
            script.filename = script_data.path

            if script.name == "Simple Upscale":
                continue

            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
        """Build the UI of one script and append its components to `inputs`.

        Records on the script the slice [args_from, args_to) of `inputs` that
        holds its components, and tags every component with the script's file
        name.
        """
        script.args_from = len(inputs)
        script.args_to = len(inputs)

        # wrap_call() returns None when ui() raised, and the base-class ui()
        # returns None as well; treat both as "no controls" instead of failing
        # on `.values()` below.
        script.controls = wrap_call(script.ui, script.filename, "ui") or {}

        for control in script.controls.values():
            control.custom_script_source = os.path.basename(script.filename)

        inputs += list(script.controls.values())
        script.args_to = len(inputs)

    def scripts_in_preferred_order(self):
        """Return the scripts sorted for display and execution.

        Scripts are initialized lazily from modules.scripts on first use.
        Sort key, in order: position of the script's name in
        shared.opts.postprocessing_operation_order (names not listed there sort
        after all listed ones), then `order`, then `name`, then original index.
        """
        if self.scripts is None:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = shared.opts.postprocessing_operation_order

        # Index of the first occurrence of every configured name.
        first_position = {}
        for i, possible_match in enumerate(scripts_order):
            first_position.setdefault(possible_match, i)

        unlisted_score = len(self.scripts)

        script_scores = {
            script.name: (
                first_position.get(script.name, unlisted_score),
                script.order,
                script.name,
                original_index,
            )
            for original_index, script in enumerate(self.scripts)
        }

        return sorted(self.scripts, key=lambda x: script_scores[x.name])

    def setup_ui(self):
        """Create the UI of every script, each inside its own gr.Row, and return all input components."""
        inputs = []

        for script in self.scripts_in_preferred_order():
            with gr.Row() as group:
                self.create_script_ui(script, inputs)

            script.group = group

        self.ui_created = True
        return inputs

    def run(self, pp: PostprocessedImage, args):
        """Run every script on `pp` and on any extra images the scripts produce.

        `args` is the flat list of component values; each script receives its
        own slice as keyword arguments named after its controls. First
        process_firstpass() is called on all scripts, then process() script by
        script. Extra images queued by a script are appended to the work list
        and are seen by the scripts that follow. On return, pp.extra_images
        holds every image other than `pp` itself.
        """
        scripts = []

        for script in self.scripts_in_preferred_order():
            script_args = args[script.args_from:script.args_to]
            process_args = dict(zip(script.controls, script_args))
            scripts.append((script, process_args))

        for script, process_args in scripts:
            script.process_firstpass(pp, **process_args)

        all_images = [pp]

        for script, process_args in scripts:
            if shared.state.skipped:
                break

            shared.state.job = script.name

            # Iterate over a snapshot: all_images grows inside the loop.
            for single_image in all_images.copy():

                if not single_image.disable_processing:
                    script.process(single_image, **process_args)

                for extra_image in single_image.extra_images:
                    # Scripts may queue raw images; wrap them so they inherit
                    # the metadata of the image they were derived from.
                    if not isinstance(extra_image, PostprocessedImage):
                        extra_image = single_image.create_copy(extra_image)

                    all_images.append(extra_image)

                single_image.extra_images.clear()

        pp.extra_images = all_images[1:]

    def create_args_for_run(self, scripts_args):
        """Build the flat argument list for run() from a {script name: {param: value}} mapping.

        Creates the UI first if that has not happened yet, because the
        argument positions are assigned while the UI is built. Parameters that
        are not supplied are left as None.
        """
        if not self.ui_created:
            with gr.Blocks(analytics_enabled=False):
                self.setup_ui()

        scripts = self.scripts_in_preferred_order()
        # default=0 keeps this working when no scripts are registered.
        args = [None] * max((x.args_to for x in scripts), default=0)

        for script in scripts:
            script_args_dict = scripts_args.get(script.name, None)
            if script_args_dict is not None:

                for i, name in enumerate(script.controls):
                    args[script.args_from + i] = script_args_dict.get(name, None)

        return args

    def image_changed(self):
        """Forward the image-changed notification to every script."""
        for script in self.scripts_in_preferred_order():
            script.image_changed()