author     InvincibleDude <81354513+InvincibleDude@users.noreply.github.com>  2023-01-30 15:35:13 +0300
committer  GitHub <noreply@github.com>  2023-01-30 15:35:13 +0300
commit     3ec2eb8bf12ae629c292ed0e96f199669040c5de (patch)
tree       fb46cb76c06f4c6a5ad4ad2ce8cd3a4577525be5
parent     0d834b9394bb1a9dbcbdc02a3d4d24d1e6511073 (diff)
parent     ee9fdf7f62984dc30770fb1a73e68736b319746f (diff)
Merge branch 'master' into improved-hr-conflict-test
-rw-r--r--  configs/instruct-pix2pix.yaml                 3
-rw-r--r--  javascript/ui.js                             36
-rw-r--r--  launch.py                                     3
-rw-r--r--  modules/generation_parameters_copypaste.py  206
-rw-r--r--  modules/img2img.py                            6
-rw-r--r--  modules/processing.py                         1
-rw-r--r--  modules/sd_samplers.py                      519
-rw-r--r--  modules/sd_samplers_common.py                78
-rw-r--r--  modules/sd_samplers_compvis.py              160
-rw-r--r--  modules/sd_samplers_kdiffusion.py           298
-rw-r--r--  modules/shared.py                            41
-rw-r--r--  modules/txt2img.py                            8
-rw-r--r--  modules/ui.py                                44
-rw-r--r--  modules/ui_common.py                          6
-rw-r--r--  modules/ui_extra_networks_checkpoints.py      3
-rw-r--r--  webui.py                                     10
16 files changed, 811 insertions, 611 deletions
diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml
index 437ddcef..4e896879 100644
--- a/configs/instruct-pix2pix.yaml
+++ b/configs/instruct-pix2pix.yaml
@@ -20,8 +20,7 @@ model:
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
- use_ema: true
- load_ema: true
+ use_ema: false
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
diff --git a/javascript/ui.js b/javascript/ui.js
index dd40e62d..b7a8268a 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt]
}
+
+promptTokecountUpdateFuncs = {}
+
+function recalculatePromptTokens(name){
+ if(promptTokecountUpdateFuncs[name]){
+ promptTokecountUpdateFuncs[name]()
+ }
+}
+
+function recalculate_prompts_txt2img(){
+ recalculatePromptTokens('txt2img_prompt')
+ recalculatePromptTokens('txt2img_neg_prompt')
+ return args_to_array(arguments);
+}
+
+function recalculate_prompts_img2img(){
+ recalculatePromptTokens('img2img_prompt')
+ recalculatePromptTokens('img2img_neg_prompt')
+ return args_to_array(arguments);
+}
+
+
opts = {}
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
@@ -232,14 +254,12 @@ onUiUpdate(function(){
return
}
-
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"
- textarea.addEventListener("input", function(){
- update_token_counter(id_button);
- });
+ promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
+ textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
}
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
@@ -273,7 +293,7 @@ onOptionsChanged(function(){
let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
-let token_timeout;
+let token_timeouts = {};
function update_txt2img_tokens(...args) {
update_token_counter("txt2img_token_button")
@@ -290,9 +310,9 @@ function update_img2img_tokens(...args) {
}
function update_token_counter(button_id) {
- if (token_timeout)
- clearTimeout(token_timeout);
- token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+ if (token_timeouts[button_id])
+ clearTimeout(token_timeouts[button_id]);
+ token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
function restart_reload(){
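
Note: the ui.js hunks above replace the single shared debounce timer (token_timeout) with one timer per button (token_timeouts), so typing in the txt2img prompt no longer cancels a pending img2img counter update. A minimal Python sketch of the same per-key debounce pattern; the names here are illustrative, not from the codebase:

import threading

WAIT_SECONDS = 0.8
_timers = {}  # one pending timer per key, mirroring token_timeouts in ui.js

def debounce(key, func):
    # cancel only the pending call for this key; other keys keep their timers
    if key in _timers:
        _timers[key].cancel()
    _timers[key] = threading.Timer(WAIT_SECONDS, func)
    _timers[key].start()

debounce("txt2img_token_button", lambda: print("update txt2img counter"))
debounce("img2img_token_button", lambda: print("update img2img counter"))
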
diff --git a/launch.py b/launch.py
index 370920de..25909469 100644
--- a/launch.py
+++ b/launch.py
@@ -223,6 +223,7 @@ def prepare_environment():
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
@@ -282,7 +283,7 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
- run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 3f224453..147eace2 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,4 +1,5 @@
import base64
+import html
import io
import math
import os
@@ -16,13 +17,23 @@ re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
+
paste_fields = {}
-bind_list = []
+registered_param_bindings = []
+
+
+class ParamBinding:
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+ self.paste_button = paste_button
+ self.tabname = tabname
+ self.source_text_component = source_text_component
+ self.source_image_component = source_image_component
+ self.source_tabname = source_tabname
+ self.override_settings_component = override_settings_component
def reset():
paste_fields.clear()
- bind_list.clear()
def quote(text):
@@ -74,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields):
modules.ui.img2img_paste_fields = fields
-def integrate_settings_paste_fields(component_dict):
- from modules import ui
-
- settings_map = {
- 'CLIP_stop_at_last_layers': 'Clip skip',
- 'inpainting_mask_weight': 'Conditional mask weight',
- 'sd_model_checkpoint': 'Model hash',
- 'eta_noise_seed_delta': 'ENSD',
- 'initial_noise_multiplier': 'Noise multiplier',
- }
- settings_paste_fields = [
- (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
- for k, v in settings_map.items()
- ]
-
- for tabname, info in paste_fields.items():
- if info["fields"] is not None:
- info["fields"] += settings_paste_fields
-
-
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
@@ -101,9 +92,60 @@ def create_buttons(tabs_list):
return buttons
-#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info):
- bind_list.append([buttons, send_image, send_generate_info])
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
+ for tabname, button in buttons.items():
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
+
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
+
+
+def register_paste_params_button(binding: ParamBinding):
+ registered_param_bindings.append(binding)
+
+
+def connect_paste_params_buttons():
+ binding: ParamBinding
+ for binding in registered_param_bindings:
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
+ fields = paste_fields[binding.tabname]["fields"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if binding.source_image_component and destination_image_component:
+ if isinstance(binding.source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
+ else:
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ binding.paste_button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[binding.source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
+
+ if binding.source_text_component is not None and fields is not None:
+ connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
+
+ if binding.source_tabname is not None and fields is not None:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+ binding.paste_button.click(
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
+ )
+
+ binding.paste_button.click(
+ fn=None,
+ _js=f"switch_to_{binding.tabname}",
+ inputs=None,
+ outputs=None,
+ )
def send_image_and_dimensions(x):
@@ -122,49 +164,6 @@ def send_image_and_dimensions(x):
return img, w, h
-def run_bind():
- for buttons, source_image_component, send_generate_info in bind_list:
- for tab in buttons:
- button = buttons[tab]
- destination_image_component = paste_fields[tab]["init_img"]
- fields = paste_fields[tab]["fields"]
-
- destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
- destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
-
- if source_image_component and destination_image_component:
- if isinstance(source_image_component, gr.Gallery):
- func = send_image_and_dimensions if destination_width_component else image_from_url_text
- jsfunc = "extract_image_from_gallery"
- else:
- func = send_image_and_dimensions if destination_width_component else lambda x: x
- jsfunc = None
-
- button.click(
- fn=func,
- _js=jsfunc,
- inputs=[source_image_component],
- outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
- )
-
- if send_generate_info and fields is not None:
- if send_generate_info in paste_fields:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
- button.click(
- fn=lambda *x: x,
- inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field, name in fields if name in paste_field_names],
- )
- else:
- connect_paste(button, fields, send_generate_info)
-
- button.click(
- fn=None,
- _js=f"switch_to_{tab}",
- inputs=None,
- outputs=None,
- )
-
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
@@ -290,7 +289,50 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-def connect_paste(button, paste_fields, input_comp, jsfunc=None):
+settings_map = {}
+
+infotext_to_setting_name_mapping = [
+ ('Clip skip', 'CLIP_stop_at_last_layers', ),
+ ('Conditional mask weight', 'inpainting_mask_weight'),
+ ('Model hash', 'sd_model_checkpoint'),
+ ('ENSD', 'eta_noise_seed_delta'),
+ ('Noise multiplier', 'initial_noise_multiplier'),
+ ('Eta', 'eta_ancestral'),
+ ('Eta DDIM', 'eta_ddim'),
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+]
+
+
+def create_override_settings_dict(text_pairs):
+ """creates processing's override_settings parameters from gradio's multiselect
+
+ Example input:
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
+
+ Example output:
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
+ """
+
+ res = {}
+
+ params = {}
+ for pair in text_pairs:
+ k, v = pair.split(":", maxsplit=1)
+
+ params[k] = v.strip()
+
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ value = params.get(param_name, None)
+
+ if value is None:
+ continue
+
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
+
+ return res
+
+
+def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(data_path, "params.txt")
@@ -327,9 +369,35 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
return res
+ if override_settings_component is not None:
+ def paste_settings(params):
+ vals = {}
+
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ vals[param_name] = v
+
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
+
button.click(
fn=paste_func,
- _js=jsfunc,
+ _js=f"recalculate_prompts_{tabname}",
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
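
Note: create_override_settings_dict above turns the multiselect's "Name: value" strings into a processing override_settings dict via infotext_to_setting_name_mapping. A self-contained sketch of that parsing step, with cast_value stubbed out (the real implementation lives on shared.opts and converts values to their typed form):

# maps infotext labels to internal option names, as in the diff above
mapping = [
    ('Clip skip', 'CLIP_stop_at_last_layers'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
]

def create_override_settings_dict(text_pairs, cast_value=lambda name, v: v):
    params = {}
    for pair in text_pairs:
        k, v = pair.split(":", maxsplit=1)
        params[k] = v.strip()
    return {setting: cast_value(setting, params[label])
            for label, setting in mapping if label in params}

print(create_override_settings_dict(['Clip skip: 2', 'ENSD: 31337']))
# {'CLIP_stop_at_last_layers': '2', 'eta_noise_seed_delta': '31337'} with the stub cast
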
diff --git a/modules/img2img.py b/modules/img2img.py
index 3ecb6146..f813299c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -7,6 +7,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from modules import devices, sd_samplers
+from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -75,7 +76,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+ override_settings = create_override_settings_dict(override_settings_texts)
+
is_batch = mode == 5
if mode == 0: # img2img
@@ -142,6 +145,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
inpaint_full_res=inpaint_full_res,
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
+ override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
diff --git a/modules/processing.py b/modules/processing.py
index a2a91a5b..49b1b4ea 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -455,7 +455,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
- "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index a7910b56..28c2136f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,53 +1,11 @@
-from collections import namedtuple, deque
-import numpy as np
-from math import floor
-import torch
-import tqdm
-from PIL import Image
-import inspect
-import k_diffusion.sampling
-import torchsde._brownian.brownian_interval
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images, sd_vae_approx
+from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-
-
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
-
-samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
- ('Euler', 'sample_euler', ['k_euler'], {}),
- ('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
- ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
- ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
- ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
-]
-
-samplers_data_k_diffusion = [
- SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
- for label, funcname, aliases, options in samplers_k_diffusion
- if hasattr(k_diffusion.sampling, funcname)
-]
+# imports for functions that previously were here and are used by other modules
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
all_samplers = [
- *samplers_data_k_diffusion,
- SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
- SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+ *sd_samplers_kdiffusion.samplers_data_k_diffusion,
+ *sd_samplers_compvis.samplers_data_compvis,
]
all_samplers_map = {x.name: x for x in all_samplers}
@@ -73,8 +31,8 @@ def create_sampler(name, model):
def set_samplers():
global samplers, samplers_for_img2img
- hidden = set(opts.hide_samplers)
- hidden_img2img = set(opts.hide_samplers + ['PLMS'])
+ hidden = set(shared.opts.hide_samplers)
+ hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
@@ -87,466 +45,3 @@ def set_samplers():
set_samplers()
-
-sampler_extra_params = {
- 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-}
-
-
-def setup_img2img_steps(p, steps=None):
- if opts.img2img_fix_steps or steps is not None:
- requested_steps = (steps or p.steps)
- steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
- t_enc = requested_steps - 1
- else:
- steps = p.steps
- t_enc = int(min(p.denoising_strength, 0.999) * steps)
-
- return steps, t_enc
-
-
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
-
-
-def single_sample_to_image(sample, approximation=None):
- if approximation is None:
- approximation = approximation_indexes.get(opts.show_progress_type, 0)
-
- if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
- elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
- else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
-
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- return Image.fromarray(x_sample)
-
-
-def sample_to_image(samples, index=0, approximation=None):
- return single_sample_to_image(samples[index], approximation)
-
-
-def samples_to_image_grid(samples, approximation=None):
- return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
-
-
-def store_latent(decoded):
- state.current_latent = decoded
-
- if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
- if not shared.parallel_processing_allowed:
- shared.state.assign_current_image(sample_to_image(decoded))
-
-
-class InterruptedException(BaseException):
- pass
-
-
-class VanillaStableDiffusionSampler:
- def __init__(self, constructor, sd_model):
- self.sampler = constructor(sd_model)
- self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.sampler_noises = None
- self.step = 0
- self.stop_at = None
- self.eta = None
- self.default_eta = 0.0
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def number_of_needed_noises(self, p):
- return 0
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except InterruptedException:
- return self.last_latent
-
- def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- if state.interrupted or state.skipped:
- raise InterruptedException
-
- if self.stop_at is not None and self.step > self.stop_at:
- raise InterruptedException
-
- # Have to unwrap the inpainting conditioning here to perform pre-processing
- image_conditioning = None
- if isinstance(cond, dict):
- image_conditioning = cond["c_concat"][0]
- cond = cond["c_crossattn"][0]
- unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
- cond = tensor
-
- # for DDIM, shapes must match, we can't just process cond and uncond independently;
- # filling unconditional_conditioning with repeats of the last vector to match length is
- # not 100% correct but should work well enough
- if unconditional_conditioning.shape[1] < cond.shape[1]:
- last_vector = unconditional_conditioning[:, -1:]
- last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
- unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
- elif unconditional_conditioning.shape[1] > cond.shape[1]:
- unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
- if self.mask is not None:
- img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x_dec = img_orig * self.mask + self.nmask * x_dec
-
- # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
- # Note that they need to be lists because it just concatenates them later.
- if image_conditioning is not None:
- cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
-
- if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
- else:
- self.last_latent = res[1]
-
- store_latent(self.last_latent)
-
- self.step += 1
- state.sampling_step = self.step
- shared.total_tqdm.update()
-
- return res
-
- def initialize(self, p):
- self.eta = p.eta if p.eta is not None else opts.eta_ddim
-
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
-
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
- def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
- valid_step = 999 / (1000 // num_steps)
- if valid_step == floor(valid_step):
- return int(valid_step) + 1
-
- return num_steps
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
- steps = self.adjust_steps_if_invalid(p, steps)
- self.initialize(p)
-
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
- self.init_latent = x
- self.last_latent = x
- self.step = 0
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- if image_conditioning is not None:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- self.initialize(p)
-
- self.init_latent = None
- self.last_latent = x
- self.step = 0
-
- steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
- if image_conditioning is not None:
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
- return samples_ddim
-
-
-class CFGDenoiser(torch.nn.Module):
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.step = 0
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
- return denoised
-
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
- if state.interrupted or state.skipped:
- raise InterruptedException
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
-
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
-
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
-
- if tensor.shape[1] == uncond.shape[1]:
- cond_in = torch.cat([tensor, uncond])
-
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
-
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
-
- devices.test_for_nans(x_out, "unet")
-
- if opts.live_preview_content == "Prompt":
- store_latent(x_out[0:uncond.shape[0]])
- elif opts.live_preview_content == "Negative prompt":
- store_latent(x_out[-uncond.shape[0]:])
-
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
-
- self.step += 1
-
- return denoised
-
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- if x.device.type == 'mps':
- return torch.randn_like(x, device=devices.cpu).to(x.device)
- else:
- return torch.randn_like(x)
-
-
-# MPS fix for randn in torchsde
-def torchsde_randn(size, dtype, device, seed):
- if device.type == 'mps':
- generator = torch.Generator(devices.cpu).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
- else:
- generator = torch.Generator(device).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
-
-
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
- denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-
- self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.funcname = funcname
- self.func = getattr(k_diffusion.sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
- self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.default_eta = 1.0
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if opts.live_preview_content == "Combined":
- store_latent(latent)
- self.last_latent = latent
-
- if self.stop_at is not None and step > self.stop_at:
- raise InterruptedException
-
- state.sampling_step = step
- shared.total_tqdm.update()
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p):
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.step = 0
- self.eta = p.eta or opts.eta_ancestral
-
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
-
- if 'eta' in inspect.signature(self.func).parameters:
- extra_params_kwargs['eta'] = self.eta
-
- return extra_params_kwargs
-
- def get_sigmas(self, p, steps):
- discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
- if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
- discard_next_to_last_sigma = True
- p.extra_generation_params["Discard penultimate sigma"] = True
-
- steps += 1 if discard_next_to_last_sigma else 0
-
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
-
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
-
- if discard_next_to_last_sigma:
- sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
-
- return sigmas
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
-
- sigmas = self.get_sigmas(p, steps)
-
- sigma_sched = sigmas[steps - t_enc - 1:]
- xi = x + noise * sigma_sched[0]
-
- extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
- ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
- extra_params_kwargs['sigma_min'] = sigma_sched[-2]
- if 'sigma_max' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigma_max'] = sigma_sched[0]
- if 'n' in inspect.signature(self.func).parameters:
- extra_params_kwargs['n'] = len(sigma_sched) - 1
- if 'sigma_sched' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigma_sched'] = sigma_sched
- if 'sigmas' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigmas'] = sigma_sched
-
- self.model_wrap_cfg.init_latent = x
- self.last_latent = x
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
- steps = steps or p.steps
-
- sigmas = self.get_sigmas(p, steps)
-
- x = x * sigmas[0]
-
- extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
- extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in inspect.signature(self.func).parameters:
- extra_params_kwargs['n'] = steps
- else:
- extra_params_kwargs['sigmas'] = sigmas
-
- self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
- return samples
-
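
Note: after this refactor, sd_samplers.py is reduced to an aggregator: sampler registration moves into sd_samplers_kdiffusion and sd_samplers_compvis, shared helpers into sd_samplers_common, and the old entry points are kept as re-exports. A minimal sketch of the aggregation pattern, using stand-in constructors:

from collections import namedtuple

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

# each backend module contributes its own list; sd_samplers merely concatenates
samplers_data_kdiffusion = [SamplerData('Euler a', lambda model: ..., ['k_euler_a'], {})]
samplers_data_compvis = [SamplerData('DDIM', lambda model: ..., [], {})]

all_samplers = [*samplers_data_kdiffusion, *samplers_data_compvis]
all_samplers_map = {x.name: x for x in all_samplers}
print(sorted(all_samplers_map))  # ['DDIM', 'Euler a']
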
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
new file mode 100644
index 00000000..3c03d442
--- /dev/null
+++ b/modules/sd_samplers_common.py
@@ -0,0 +1,78 @@
+from collections import namedtuple
+import numpy as np
+import torch
+from PIL import Image
+import torchsde._brownian.brownian_interval
+from modules import devices, processing, images, sd_vae_approx
+
+from modules.shared import opts, state
+import modules.shared as shared
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+def setup_img2img_steps(p, steps=None):
+ if opts.img2img_fix_steps or steps is not None:
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
+ else:
+ steps = p.steps
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+ return steps, t_enc
+
+
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ return Image.fromarray(x_sample)
+
+
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
+def store_latent(decoded):
+ state.current_latent = decoded
+
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if not shared.parallel_processing_allowed:
+ shared.state.assign_current_image(sample_to_image(decoded))
+
+
+class InterruptedException(BaseException):
+ pass
+
+
+# MPS fix for randn in torchsde
+# XXX move this to separate file for MPS
+def torchsde_randn(size, dtype, device, seed):
+ if device.type == 'mps':
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ else:
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
+
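
Note: setup_img2img_steps in the new common module decides both the total schedule length and the number of encoding steps (t_enc) to actually run for img2img. A worked standalone version of the same arithmetic, with opts.img2img_fix_steps passed in as a plain argument:

def setup_img2img_steps(steps, denoising_strength, fix_steps=False):
    if fix_steps:
        # stretch the schedule so the requested number of steps actually runs
        total = int(steps / min(denoising_strength, 0.999)) if denoising_strength > 0 else 0
        return total, steps - 1
    # default: run only the last denoising_strength fraction of the schedule
    return steps, int(min(denoising_strength, 0.999) * steps)

print(setup_img2img_steps(20, 0.75))        # (20, 15): 15 of 20 steps run
print(setup_img2img_steps(20, 0.75, True))  # (26, 19): schedule stretched to 26
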
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
new file mode 100644
index 00000000..d03131cd
--- /dev/null
+++ b/modules/sd_samplers_compvis.py
@@ -0,0 +1,160 @@
+import math
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+import numpy as np
+import torch
+
+from modules.shared import state
+from modules import sd_samplers_common, prompt_parser, shared
+
+
+samplers_data_compvis = [
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+]
+
+
+class VanillaStableDiffusionSampler:
+ def __init__(self, constructor, sd_model):
+ self.sampler = constructor(sd_model)
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.sampler_noises = None
+ self.step = 0
+ self.stop_at = None
+ self.eta = None
+ self.config = None
+ self.last_latent = None
+
+ self.conditioning_key = sd_model.model.conditioning_key
+
+ def number_of_needed_noises(self, p):
+ return 0
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except sd_samplers_common.InterruptedException:
+ return self.last_latent
+
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ if self.stop_at is not None and self.step > self.stop_at:
+ raise sd_samplers_common.InterruptedException
+
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ cond = tensor
+
+ # for DDIM, shapes must match, we can't just process cond and uncond independently;
+ # filling unconditional_conditioning with repeats of the last vector to match length is
+ # not 100% correct but should work well enough
+ if unconditional_conditioning.shape[1] < cond.shape[1]:
+ last_vector = unconditional_conditioning[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+ unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+ elif unconditional_conditioning.shape[1] > cond.shape[1]:
+ unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
+ if self.mask is not None:
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+ x_dec = img_orig * self.mask + self.nmask * x_dec
+
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ if self.mask is not None:
+ self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+ else:
+ self.last_latent = res[1]
+
+ sd_samplers_common.store_latent(self.last_latent)
+
+ self.step += 1
+ state.sampling_step = self.step
+ shared.total_tqdm.update()
+
+ return res
+
+ def initialize(self, p):
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ if self.eta != 0.0:
+ p.extra_generation_params["Eta DDIM"] = self.eta
+
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+ if hasattr(self.sampler, fieldname):
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+
+ self.mask = p.mask if hasattr(p, 'mask') else None
+ self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == math.floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+ steps = self.adjust_steps_if_invalid(p, steps)
+ self.initialize(p)
+
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
+ x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
+
+ self.init_latent = x
+ self.last_latent = x
+ self.step = 0
+
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ self.initialize(p)
+
+ self.init_latent = None
+ self.last_latent = x
+ self.step = 0
+
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
+
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
+ if image_conditioning is not None:
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+
+ return samples_ddim
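
Note: adjust_steps_if_invalid guards against step counts that make DDIM's uniform discretization land exactly on timestep 999; when 999 divides evenly by the stride 1000 // num_steps, it bumps the count by one. A standalone check of the arithmetic:

import math

def adjust_steps_if_invalid(num_steps):
    # uniform DDIM stride is 1000 // num_steps; a schedule whose last
    # step falls exactly on timestep 999 needs one extra step
    valid_step = 999 / (1000 // num_steps)
    if valid_step == math.floor(valid_step):
        return int(valid_step) + 1
    return num_steps

print(adjust_steps_if_invalid(50))   # 50  (999 / 20 = 49.95, not integral)
print(adjust_steps_if_invalid(333))  # 334 (999 / 3 == 333 exactly)
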
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
new file mode 100644
index 00000000..aa7f106b
--- /dev/null
+++ b/modules/sd_samplers_kdiffusion.py
@@ -0,0 +1,298 @@
+from collections import deque
+import torch
+import inspect
+import k_diffusion.sampling
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+
+samplers_k_diffusion = [
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler', 'sample_euler', ['k_euler'], {}),
+ ('LMS', 'sample_lms', ['k_lms'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+]
+
+samplers_data_k_diffusion = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_k_diffusion
+ if hasattr(k_diffusion.sampling, funcname)
+]
+
+sampler_extra_params = {
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+}
+
+
+class CFGDenoiser(torch.nn.Module):
+ """
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+ negative prompt.
+ """
+
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.step = 0
+
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+
+ if tensor.shape[1] == uncond.shape[1]:
+ cond_in = torch.cat([tensor, uncond])
+
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
+
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+
+ devices.test_for_nans(x_out, "unet")
+
+ if opts.live_preview_content == "Prompt":
+ sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
+
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+ if self.mask is not None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+
+ self.step += 1
+
+ return denoised
+
+
+class TorchHijack:
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.randn_like
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ if x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+ else:
+ return torch.randn_like(x)
+
+
+class KDiffusionSampler:
+ def __init__(self, funcname, sd_model):
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
+ self.funcname = funcname
+ self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname, [])
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ self.sampler_noises = None
+ self.stop_at = None
+ self.eta = None
+ self.config = None
+ self.last_latent = None
+
+ self.conditioning_key = sd_model.model.conditioning_key
+
+ def callback_state(self, d):
+ step = d['i']
+ latent = d["denoised"]
+ if opts.live_preview_content == "Combined":
+ sd_samplers_common.store_latent(latent)
+ self.last_latent = latent
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise sd_samplers_common.InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except sd_samplers_common.InterruptedException:
+ return self.last_latent
+
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def initialize(self, p):
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.step = 0
+ self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+
+ extra_params_kwargs = {}
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
+
+ if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params["Eta"] = self.eta
+
+ extra_params_kwargs['eta'] = self.eta
+
+ return extra_params_kwargs
+
+ def get_sigmas(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
+
+ if p.sampler_noise_scheduler_override:
+ sigmas = p.sampler_noise_scheduler_override(steps)
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+ else:
+ sigmas = self.model_wrap.get_sigmas(steps)
+
+ if discard_next_to_last_sigma:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+ return sigmas
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+ sigmas = self.get_sigmas(p, steps)
+
+ sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+
+ extra_params_kwargs = self.initialize(p)
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ # the last sigma is zero, which isn't allowed by DPM Fast & Adaptive, so take the value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
+
+ self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps = steps or p.steps
+
+ sigmas = self.get_sigmas(p, steps)
+
+ x = x * sigmas[0]
+
+ extra_params_kwargs = self.initialize(p)
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = steps
+ else:
+ extra_params_kwargs['sigmas'] = sigmas
+
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ return samples
+
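
The TorchHijack class above works because k_diffusion's samplers resolve torch.randn_like through their module-level torch name, so assigning k_diffusion.sampling.torch = TorchHijack(...) in initialize() feeds the queued noises to the sampler without touching k_diffusion itself. A minimal self-contained sketch of that module-attribute patch, using a stand-in module and a condensed hijack rather than the real k_diffusion.sampling:

import types
from collections import deque
import torch

sampling = types.ModuleType("sampling_stub")   # stand-in for k_diffusion.sampling
sampling.torch = torch                         # sampler code resolves torch.* via this name

class NoiseHijack:                             # condensed version of TorchHijack above
    def __init__(self, noises):
        self.noises = deque(noises)
    def __getattr__(self, item):
        return getattr(torch, item)            # fall through to real torch for everything else
    def randn_like(self, x):
        if self.noises and self.noises[0].shape == x.shape:
            return self.noises.popleft()       # hand out a queued noise when shapes match
        return torch.randn_like(x)

x = torch.zeros(2, 3)
queued = torch.ones(2, 3)
sampling.torch = NoiseHijack([queued])         # mirrors what initialize() does
assert torch.equal(sampling.torch.randn_like(x), queued)  # queued noise is used first
assert sampling.torch.randn_like(x).shape == x.shape      # then ordinary randn_like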
diff --git a/modules/shared.py b/modules/shared.py
index eb04e811..69634fd8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -105,6 +105,8 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
+
script_loading.preload_extensions(extensions.extensions_dir, parser)
@@ -127,12 +129,13 @@ restricted_opts = {
ui_reorder_categories = [
"inpaint",
"sampler",
+ "checkboxes",
+ "hires_fix",
"dimensions",
"cfg",
"seed",
- "checkboxes",
- "hires_fix",
"batch",
+ "override_settings",
"scripts",
]
@@ -346,10 +349,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
- "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
- "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
+ "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
+ "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
@@ -440,7 +443,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
- "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
@@ -605,11 +608,37 @@ class Options:
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+ def cast_value(self, key, value):
+ """casts an arbitrary to the same type as this setting's value with key
+ Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
+ """
+
+ if value is None:
+ return None
+
+ default_value = self.data_labels[key].default
+ if default_value is None:
+ default_value = getattr(self, key, None)
+ if default_value is None:
+ return None
+
+ expected_type = type(default_value)
+ if expected_type == bool and value == "False":
+ value = False
+ else:
+ value = expected_type(value)
+
+ return value
+
+
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+settings_components = None
+"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
+
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
diff --git a/modules/txt2img.py b/modules/txt2img.py
index c06f9f9d..a938eaaa 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,5 +1,6 @@
import modules.scripts
from modules import sd_samplers
+from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
@@ -8,7 +9,9 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, *args):
+
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+ override_settings = create_override_settings_dict(override_settings_texts)
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -40,7 +43,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_y=hr_resize_y,
hr_sampler=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else '---',
hr_prompt=hr_prompt,
- hr_negative_prompt=hr_negative_prompt
+ hr_negative_prompt=hr_negative_prompt,
+ override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
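
txt2img now receives the dropdown's selections as override_settings_texts and converts them with create_override_settings_dict, which is defined in modules/generation_parameters_copypaste.py and not shown in this hunk. As a hedged guess at its shape, each entry is a "Label: value" string mapped to a setting name and cast with opts.cast_value, roughly:

# Hypothetical sketch only; the real mapping table and function live in
# modules/generation_parameters_copypaste.py.
def create_override_settings_dict_sketch(text_pairs, opts, label_to_setting):
    res = {}
    for pair in text_pairs:                      # e.g. "Clip skip: 2"
        label, _, value = pair.partition(":")
        key = label_to_setting.get(label.strip())
        if key is not None:
            res[key] = opts.cast_value(key, value.strip())
    return res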
diff --git a/modules/ui.py b/modules/ui.py
index c50b4c9c..dca67fed 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -380,6 +380,7 @@ def apply_setting(key, value):
opts.save(shared.config_filename)
return getattr(opts, key)
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@@ -433,6 +434,18 @@ def get_value_for_setting(key):
return gr.update(value=value, **args)
+def create_override_settings_dropdown(tabname, row):
+ dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
+
+ dropdown.change(
+ fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
+ inputs=[dropdown],
+ outputs=[dropdown],
+ )
+
+ return dropdown
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -514,6 +527,10 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+ elif category == "override_settings":
+ with FormRow(elem_id="txt2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('txt2img', row)
+
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
@@ -535,7 +552,6 @@ def create_ui():
)
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
- parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -569,6 +585,8 @@ def create_ui():
hr_sampler_index,
hr_prompt,
hr_negative_prompt,
+ override_settings,
+
] + custom_inputs,
outputs=[
@@ -632,6 +650,9 @@ def create_ui():
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings,
+ ))
txt2img_preview_params = [
txt2img_prompt,
@@ -779,6 +800,10 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+ elif category == "override_settings":
+ with FormRow(elem_id="img2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('img2img', row)
+
elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
@@ -813,7 +838,6 @@ def create_ui():
)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
- parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -866,7 +890,8 @@ def create_ui():
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
- img2img_batch_inpaint_mask_dir
+ img2img_batch_inpaint_mask_dir,
+ override_settings,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -954,6 +979,9 @@ def create_ui():
]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings,
+ ))
modules.scripts.scripts_current = None
@@ -971,7 +999,11 @@ def create_ui():
html2 = gr.HTML()
with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
- parameters_copypaste.bind_buttons(buttons, image, generation_info)
+
+ for tabname, button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
+ ))
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
@@ -1380,6 +1412,7 @@ def create_ui():
components = []
component_dict = {}
+ shared.settings_components = component_dict
script_callbacks.ui_settings_callback()
opts.reorder()
@@ -1546,8 +1579,7 @@ def create_ui():
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
- parameters_copypaste.integrate_settings_paste_fields(component_dict)
- parameters_copypaste.run_bind()
+ parameters_copypaste.connect_paste_params_buttons()
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
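
The repeated register_paste_params_button calls replace the old bind_buttons/run_bind pair: each tab records a ParamBinding while it is being built, and connect_paste_params_buttons wires them all up once every component exists. A minimal sketch of that register-then-connect pattern; the field names mirror the calls above, but the real class lives in modules/generation_parameters_copypaste.py:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ParamBindingSketch:
    paste_button: Any
    tabname: str
    source_text_component: Any = None
    source_image_component: Any = None
    source_tabname: Optional[str] = None
    override_settings_component: Any = None

registered = []

def register_paste_params_button(binding):
    registered.append(binding)       # collected during UI construction

def connect_paste_params_buttons():
    for b in registered:             # run once, after all tabs/components exist
        print(f"wiring paste button for {b.tabname}")

register_paste_params_button(ParamBindingSketch(paste_button=object(), tabname="txt2img"))
connect_paste_params_buttons()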
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 9405ac1f..fd047f31 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -198,5 +198,9 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ for paste_tabname, paste_button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
+ ))
+
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index a6799171..04097a79 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -14,6 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
shared.refresh_checkpoints()
def list_items(self):
+ checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items():
path, ext = os.path.splitext(checkpoint.filename)
previews = [path + ".png", path + ".preview.png"]
@@ -28,7 +29,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"name": checkpoint.name_for_extra,
"filename": path,
"preview": preview,
- "search_term": self.search_terms_from_path(checkpoint.filename),
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"local_preview": path + ".png",
}
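
Appending checkpoint.sha256 to search_term means the extra-networks search box can now find a checkpoint card by hash fragment as well as by path. Assuming the JS side does a plain case-insensitive substring match against search_term, the effect is:

def card_matches(search_term: str, query: str) -> bool:
    # assumed behaviour of the extra-networks filter: case-insensitive substring match
    return query.lower() in search_term.lower()

term = "models/Stable-diffusion/v1-5" + " " + "e1441589a6"  # example path + sha256 fragment
assert card_matches(term, "e1441589")   # new: find a checkpoint by hash
assert card_matches(term, "v1-5")       # old: find by name still works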
diff --git a/webui.py b/webui.py
index 0d0b8364..5b5c2139 100644
--- a/webui.py
+++ b/webui.py
@@ -52,6 +52,9 @@ else:
def check_versions():
+ if shared.cmd_opts.skip_version_check:
+ return
+
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
@@ -59,7 +62,10 @@ def check_versions():
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded.
+Beware that this will cause a lot of large files to be downloaded. There are
+also reports of issues with the training tab on the latest torch version.
+
+Use --skip-version-check commandline argument to disable this check.
""".strip())
expected_xformers_version = "0.0.16rc425"
@@ -71,6 +77,8 @@ Beware that this will cause a lot of large files to be downloaded.
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
""".strip())