aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py89
-rw-r--r--modules/api/models.py2
-rw-r--r--modules/generation_parameters_copypaste.py85
-rw-r--r--modules/processing.py4
-rw-r--r--modules/processing_scripts/refiner.py7
-rw-r--r--modules/processing_scripts/seed.py13
-rw-r--r--modules/ui.py46
7 files changed, 187 insertions, 59 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 7154c9d5..2918f785 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -312,8 +312,13 @@ class Api:
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
- def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
+ def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
script_args = default_script_args.copy()
+
+ if input_script_args is not None:
+ for index, value in input_script_args.items():
+ script_args[index] = value
+
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@@ -335,6 +340,75 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
+ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
+ """Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext.
+
+ If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
+
+ Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
+ """
+
+ if not request.infotext:
+ return {}
+
+ possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"]
+ set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this
+ params = generation_parameters_copypaste.parse_generation_parameters(request.infotext)
+
+ def get_field_value(field, params):
+ value = field.function(params) if field.function else params.get(field.label)
+ if value is None:
+ return None
+
+ if field.api in request.__fields__:
+ target_type = request.__fields__[field.api].type_
+ else:
+ target_type = type(field.component.value)
+
+ if target_type == type(None):
+ return None
+
+ if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
+ value = value.get('value')
+
+ if value is not None and not isinstance(value, target_type):
+ value = target_type(value)
+
+ return value
+
+ for field in possible_fields:
+ if not field.api:
+ continue
+
+ if field.api in set_fields:
+ continue
+
+ value = get_field_value(field, params)
+ if value is not None:
+ setattr(request, field.api, value)
+
+ if request.override_settings is None:
+ request.override_settings = {}
+
+ overriden_settings = generation_parameters_copypaste.get_override_settings(params)
+ for _, setting_name, value in overriden_settings:
+ if setting_name not in request.override_settings:
+ request.override_settings[setting_name] = value
+
+ if script_runner is not None and mentioned_script_args is not None:
+ indexes = {v: i for i, v in enumerate(script_runner.inputs)}
+ script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
+
+ for field, index in script_fields:
+ value = get_field_value(field, params)
+
+ if value is None:
+ continue
+
+ mentioned_script_args[index] = value
+
+ return params
+
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
@@ -342,6 +416,10 @@ class Api:
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()
+
+ infotext_script_args = {}
+ self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
@@ -358,8 +436,9 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
@@ -406,6 +485,10 @@ class Api:
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui()
+
+ infotext_script_args = {}
+ self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
@@ -425,7 +508,7 @@ class Api:
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
- script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
diff --git a/modules/api/models.py b/modules/api/models.py
index 58083a34..16edf11c 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -108,6 +108,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -126,6 +127,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index dbffe494..86a36c32 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -28,6 +28,19 @@ class ParamBinding:
self.paste_field_names = paste_field_names or []
+class PasteField(tuple):
+ def __new__(cls, component, target, *, api=None):
+ return super().__new__(cls, (component, target))
+
+ def __init__(self, component, target, *, api=None):
+ super().__init__()
+
+ self.api = api
+ self.component = component
+ self.label = target if isinstance(target, str) else None
+ self.function = target if callable(target) else None
+
+
paste_fields: dict[str, dict] = {}
registered_param_bindings: list[ParamBinding] = []
@@ -84,6 +97,12 @@ def image_from_url_text(filedata):
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+
+ if fields:
+ for i in range(len(fields)):
+ if not isinstance(fields[i], PasteField):
+ fields[i] = PasteField(*fields[i])
+
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
@@ -371,6 +390,48 @@ def create_override_settings_dict(text_pairs):
return res
+def get_override_settings(params, *, skip_fields=None):
+ """Returns a list of settings overrides from the infotext parameters dictionary.
+
+ This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
+ a list of tuples containing the parameter name, setting name, and new value cast to correct type.
+
+ It checks for conditions before adding an override:
+ - ignores settings that match the current value
+ - ignores parameter keys present in skip_fields argument.
+
+ Example input:
+ {"Clip skip": "2"}
+
+ Example output:
+ [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+ """
+
+ res = []
+
+ mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+ for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+ if param_name in (skip_fields or {}):
+ continue
+
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ res.append((param_name, setting_name, v))
+
+ return res
+
+
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
@@ -412,29 +473,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
- vals = {}
-
- mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
- for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
- if param_name in already_handled_fields:
- continue
-
- v = params.get(param_name, None)
- if v is None:
- continue
-
- if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
- continue
-
- v = shared.opts.cast_value(setting_name, v)
- current_value = getattr(shared.opts, setting_name, None)
-
- if v == current_value:
- continue
-
- vals[param_name] = v
+ vals = get_override_settings(params, skip_fields=already_handled_fields)
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+ vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
diff --git a/modules/processing.py b/modules/processing.py
index 9351e3fb..2f11d5f8 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1135,7 +1135,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if self.hr_checkpoint_name:
+ if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
if self.hr_checkpoint_info is None:
@@ -1482,7 +1482,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
- images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
image = images.flatten(img, opts.img2img_background_color)
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 29ccb78f..cefad32b 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -1,6 +1,7 @@
import gradio as gr
from modules import scripts, sd_models
+from modules.generation_parameters_copypaste import PasteField
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
@@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
return None if info is None else info.title
self.infotext_fields = [
- (enable_refiner, lambda d: 'Refiner' in d),
- (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
- (refiner_switch_at, 'Refiner switch at'),
+ PasteField(enable_refiner, lambda d: 'Refiner' in d),
+ PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
+ PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index dc9c2da5..a3e16a12 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -3,6 +3,7 @@ import json
import gradio as gr
from modules import scripts, ui, errors
+from modules.generation_parameters_copypaste import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
@@ -51,12 +52,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
- (self.seed, "Seed"),
- (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
+ PasteField(self.seed, "Seed", api="seed"),
+ PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
+ PasteField(subseed, "Variation seed", api="subseed"),
+ PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
+ PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
+ PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
diff --git a/modules/ui.py b/modules/ui.py
index d80486dd..9db2407e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -28,7 +28,7 @@ import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.generation_parameters_copypaste import image_from_url_text, PasteField
create_setting_component = ui_settings.create_setting_component
@@ -436,28 +436,28 @@ def create_ui():
)
txt2img_paste_fields = [
- (toprow.prompt, "Prompt"),
- (toprow.negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_name, "Sampler"),
- (cfg_scale, "CFG scale"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- (hr_checkpoint_name, "Hires checkpoint"),
- (hr_sampler_name, "Hires sampler"),
- (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
- (hr_prompt, "Hires prompt"),
- (hr_negative_prompt, "Hires negative prompt"),
- (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
+ PasteField(toprow.prompt, "Prompt", api="prompt"),
+ PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
+ PasteField(steps, "Steps", api="steps"),
+ PasteField(sampler_name, "Sampler", api="sampler_name"),
+ PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
+ PasteField(width, "Size-1", api="width"),
+ PasteField(height, "Size-2", api="height"),
+ PasteField(batch_size, "Batch size", api="batch_size"),
+ PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
+ PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
+ PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
+ PasteField(hr_scale, "Hires upscale", api="hr_scale"),
+ PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
+ PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
+ PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
+ PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
+ PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
+ PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
+ PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
+ PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
+ PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
+ PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)