aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/generation_parameters_copypaste.py5
-rw-r--r--modules/img2img.py6
-rw-r--r--modules/launch_utils.py29
-rw-r--r--modules/processing.py3
-rw-r--r--modules/sd_hijack.py4
-rw-r--r--modules/sd_hijack_inpainting.py95
-rw-r--r--modules/sd_models.py3
-rw-r--r--modules/sd_samplers.py19
-rw-r--r--modules/sd_samplers_cfg_denoiser.py203
-rw-r--r--modules/sd_samplers_common.py138
-rw-r--r--modules/sd_samplers_compvis.py232
-rw-r--r--modules/sd_samplers_kdiffusion.py379
-rw-r--r--modules/sd_samplers_timesteps.py147
-rw-r--r--modules/sd_samplers_timesteps_impl.py135
-rw-r--r--modules/sd_vae.py71
-rw-r--r--modules/shared.py21
-rw-r--r--modules/txt2img.py8
-rw-r--r--modules/ui.py34
-rw-r--r--modules/ui_extra_networks_checkpoints_user_metadata.py8
19 files changed, 788 insertions, 752 deletions
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 6711ca16..20e30b53 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -416,10 +416,15 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
return res
if override_settings_component is not None:
+ already_handled_fields = {key: 1 for _, key in paste_fields}
+
def paste_settings(params):
vals = {}
for param_name, setting_name in infotext_to_setting_name_mapping:
+ if param_name in already_handled_fields:
+ continue
+
v = params.get(param_name, None)
if v is None:
continue
diff --git a/modules/img2img.py b/modules/img2img.py
index d8e1c534..e06ac1d6 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -6,7 +6,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr
-from modules import sd_samplers, images as imgutil
+from modules import images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
@@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p)
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
+ sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index f77b577a..5be30a18 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -139,6 +139,27 @@ def check_run_python(code: str) -> bool:
return result.returncode == 0
+def git_fix_workspace(dir, name):
+ run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
+ run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
+ return
+
+
+def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
+ try:
+ return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
+ except RuntimeError:
+ pass
+
+ if not autofix:
+ return None
+
+ print(f"{errdesc}, attempting autofix...")
+ git_fix_workspace(dir, name)
+
+ return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
+
+
def git_clone(url, dir, name, commithash=None):
# TODO clone into temporary dir and move if successful
@@ -146,12 +167,14 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None:
return
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
+ current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return
- run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+ run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+
+ run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+
return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
diff --git a/modules/processing.py b/modules/processing.py
index ec66fd8e..b635cc74 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1119,9 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
- if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
- img2img_sampler_name = 'DDIM'
-
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
if self.latent_scale_mode is not None:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9ad98199..46652fbd 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
-sd_hijack_inpainting.do_inpainting_hijack()
-
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
deleted file mode 100644
index 2d44b856..00000000
--- a/modules/sd_hijack_inpainting.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import torch
-
-import ldm.models.diffusion.ddpm
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-from ldm.models.diffusion.ddim import noise_like
-from ldm.models.diffusion.sampling_util import norm_thresholding
-
-
-@torch.no_grad()
-def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
- b, *_, device = *x.shape, x.device
-
- def get_model_output(x, t):
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
-
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = {}
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- else:
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- c_in = torch.cat([unconditional_conditioning, c])
-
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- return e_t
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
- def get_x_prev_and_pred_x0(e_t, index):
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- if dynamic_threshold is not None:
- pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- e_t = get_model_output(x, t)
- if len(old_eps) == 0:
- # Pseudo Improved Euler (2nd order)
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
- e_t_next = get_model_output(x_prev, t_next)
- e_t_prime = (e_t + e_t_next) / 2
- elif len(old_eps) == 1:
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
- elif len(old_eps) == 2:
- # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
- elif len(old_eps) >= 3:
- # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
- return x_prev, pred_x0, e_t
-
-
-def do_inpainting_hijack():
- ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 981aa93d..a97af215 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -372,7 +372,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")
@@ -715,6 +715,7 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
+ sd_unet.apply_unet()
return sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index bea2684c..45faae62 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,17 +1,18 @@
-from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
- *sd_samplers_compvis.samplers_data_compvis,
+ *sd_samplers_timesteps.samplers_data_timesteps,
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers_map = {}
+samplers_hidden = {}
def find_sampler_config(name):
@@ -38,13 +39,11 @@ def create_sampler(name, model):
def set_samplers():
- global samplers, samplers_for_img2img
+ global samplers, samplers_for_img2img, samplers_hidden
- hidden = set(shared.opts.hide_samplers)
- hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
-
- samplers = [x for x in all_samplers if x.name not in hidden]
- samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+ samplers_hidden = set(shared.opts.hide_samplers)
+ samplers = all_samplers
+ samplers_for_img2img = all_samplers
samplers_map.clear()
for sampler in all_samplers:
@@ -53,4 +52,8 @@ def set_samplers():
samplers_map[alias.lower()] = sampler.name
+def visible_sampler_names():
+ return [x.name for x in samplers if x.name not in samplers_hidden]
+
+
set_samplers()
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
new file mode 100644
index 00000000..d826222c
--- /dev/null
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -0,0 +1,203 @@
+import torch
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+
+
+def catenate_conds(conds):
+ if not isinstance(conds[0], dict):
+ return torch.cat(conds)
+
+ return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+ if not isinstance(cond, dict):
+ return cond[a:b]
+
+ return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+ if not isinstance(tensor, dict):
+ return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+ tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+ return tensor
+
+
+class CFGDenoiser(torch.nn.Module):
+ """
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+ negative prompt.
+ """
+
+ def __init__(self, model, sampler):
+ super().__init__()
+ self.inner_model = model
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.step = 0
+ self.image_cfg_scale = None
+ self.padded_cond_uncond = False
+ self.sampler = sampler
+
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+ return denoised
+
+ def get_pred_x0(self, x_in, x_out, sigma):
+ return x_out
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+ # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+ if self.mask is not None:
+ x = self.init_latent * self.mask + self.nmask * x
+
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ if shared.sd_model.model.conditioning_key == "crossattn-adm":
+ image_uncond = torch.zeros_like(image_cond)
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
+ else:
+ image_uncond = image_cond
+ if isinstance(uncond, dict):
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+ else:
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
+
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
+ skip_uncond = False
+
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ self.padded_cond_uncond = False
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = pad_cond(tensor, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
+ cond_in = catenate_conds([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = catenate_conds([tensor, uncond])
+
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+
+ if not is_edit_model:
+ c_crossattn = subscript_cond(tensor, a, b)
+ else:
+ c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
+
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
+
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+ cfg_denoised_callback(denoised_params)
+
+ devices.test_for_nans(x_out, "unet")
+
+ if is_edit_model:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+ self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+ if opts.live_preview_content == "Prompt":
+ preview = self.sampler.last_latent
+ elif opts.live_preview_content == "Negative prompt":
+ preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+ else:
+ preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+ sd_samplers_common.store_latent(preview)
+
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
+ self.step += 1
+ return denoised
+
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 92bf0ca1..15f27970 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,9 +1,11 @@
-from collections import namedtuple
+import inspect
+from collections import namedtuple, deque
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
+import k_diffusion.sampling
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -155,3 +157,137 @@ def apply_refiner(sampler):
return True
+class TorchHijack:
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.randn_like
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ return devices.randn_like(x)
+
+
+class Sampler:
+ def __init__(self, funcname):
+ self.funcname = funcname
+ self.func = funcname
+ self.extra_params = []
+ self.sampler_noises = None
+ self.stop_at = None
+ self.eta = None
+ self.config = None # set by the function calling the constructor
+ self.last_latent = None
+ self.s_min_uncond = None
+ self.s_churn = 0.0
+ self.s_tmin = 0.0
+ self.s_tmax = float('inf')
+ self.s_noise = 1.0
+
+ self.eta_option_field = 'eta_ancestral'
+ self.eta_infotext_field = 'Eta'
+
+ self.conditioning_key = shared.sd_model.model.conditioning_key
+
+ self.model_wrap = None
+ self.model_wrap_cfg = None
+
+ def callback_state(self, d):
+ step = d['i']
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
+ except InterruptedException:
+ return self.last_latent
+
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def initialize(self, p) -> dict:
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.step = 0
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+ self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
+
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+
+ extra_params_kwargs = {}
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
+
+ if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params[self.eta_infotext_field] = self.eta
+
+ extra_params_kwargs['eta'] = self.eta
+
+ if len(self.extra_params) > 0:
+ s_churn = getattr(opts, 's_churn', p.s_churn)
+ s_tmin = getattr(opts, 's_tmin', p.s_tmin)
+ s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
+ s_noise = getattr(opts, 's_noise', p.s_noise)
+
+ if s_churn != self.s_churn:
+ extra_params_kwargs['s_churn'] = s_churn
+ p.s_churn = s_churn
+ p.extra_generation_params['Sigma churn'] = s_churn
+ if s_tmin != self.s_tmin:
+ extra_params_kwargs['s_tmin'] = s_tmin
+ p.s_tmin = s_tmin
+ p.extra_generation_params['Sigma tmin'] = s_tmin
+ if s_tmax != self.s_tmax:
+ extra_params_kwargs['s_tmax'] = s_tmax
+ p.s_tmax = s_tmax
+ p.extra_generation_params['Sigma tmax'] = s_tmax
+ if s_noise != self.s_noise:
+ extra_params_kwargs['s_noise'] = s_noise
+ p.s_noise = s_noise
+ p.extra_generation_params['Sigma noise'] = s_noise
+
+ return extra_params_kwargs
+
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
+
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
+
+
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index 2eeec18a..e69de29b 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -1,232 +0,0 @@
-import math
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-import numpy as np
-import torch
-
-from modules.shared import state
-from modules import sd_samplers_common, prompt_parser, shared
-import modules.models.diffusion.uni_pc
-
-
-samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
- sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
- sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
-]
-
-
-class VanillaStableDiffusionSampler:
- def __init__(self, constructor, sd_model):
- self.p = None
- self.sampler = constructor(shared.sd_model)
- self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
- self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
- self.orig_p_sample_ddim = None
- if self.is_plms:
- self.orig_p_sample_ddim = self.sampler.p_sample_plms
- elif self.is_ddim:
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.sampler_noises = None
- self.steps = None
- self.step = 0
- self.stop_at = None
- self.eta = None
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def number_of_needed_noises(self, p):
- return 0
-
- def launch_sampling(self, steps, func):
- self.steps = steps
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
-
- res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
-
- x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
-
- return res
-
- def update_inner_model(self):
- self.sampler.model = shared.sd_model
-
- def before_sample(self, x, ts, cond, unconditional_conditioning):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
- if self.stop_at is not None and self.step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- # Have to unwrap the inpainting conditioning here to perform pre-processing
- image_conditioning = None
- uc_image_conditioning = None
- if isinstance(cond, dict):
- if self.conditioning_key == "crossattn-adm":
- image_conditioning = cond["c_adm"]
- uc_image_conditioning = unconditional_conditioning["c_adm"]
- else:
- image_conditioning = cond["c_concat"][0]
- cond = cond["c_crossattn"][0]
- unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
- assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
- cond = tensor
-
- # for DDIM, shapes must match, we can't just process cond and uncond independently;
- # filling unconditional_conditioning with repeats of the last vector to match length is
- # not 100% correct but should work well enough
- if unconditional_conditioning.shape[1] < cond.shape[1]:
- last_vector = unconditional_conditioning[:, -1:]
- last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
- unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
- elif unconditional_conditioning.shape[1] > cond.shape[1]:
- unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
- if self.mask is not None:
- img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x = img_orig * self.mask + self.nmask * x
-
- # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
- # Note that they need to be lists because it just concatenates them later.
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
- unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
- else:
- cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- return x, ts, cond, unconditional_conditioning
-
- def update_step(self, last_latent):
- if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
- else:
- self.last_latent = last_latent
-
- sd_samplers_common.store_latent(self.last_latent)
-
- self.step += 1
- state.sampling_step = self.step
- shared.total_tqdm.update()
-
- def after_sample(self, x, ts, cond, uncond, res):
- if not self.is_unipc:
- self.update_step(res[1])
-
- return x, ts, cond, uncond, res
-
- def unipc_after_update(self, x, model_x):
- self.update_step(x)
-
- def initialize(self, p):
- self.p = p
-
- if self.is_ddim:
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
- else:
- self.eta = 0.0
-
- if self.eta != 0.0:
- p.extra_generation_params["Eta DDIM"] = self.eta
-
- if self.is_unipc:
- keys = [
- ('UniPC variant', 'uni_pc_variant'),
- ('UniPC skip type', 'uni_pc_skip_type'),
- ('UniPC order', 'uni_pc_order'),
- ('UniPC lower order final', 'uni_pc_lower_order_final'),
- ]
-
- for name, key in keys:
- v = getattr(shared.opts, key)
- if v != shared.opts.get_default(key):
- p.extra_generation_params[name] = v
-
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
- if self.is_unipc:
- self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
-
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
-
- def adjust_steps_if_invalid(self, p, num_steps):
- if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
- if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
- num_steps = shared.opts.uni_pc_order
- valid_step = 999 / (1000 // num_steps)
- if valid_step == math.floor(valid_step):
- return int(valid_step) + 1
-
- return num_steps
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
- steps = self.adjust_steps_if_invalid(p, steps)
- self.initialize(p)
-
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
- self.init_latent = x
- self.last_latent = x
- self.step = 0
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
- else:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- self.initialize(p)
-
- self.init_latent = None
- self.last_latent = x
- self.step = 0
-
- steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
- else:
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
- return samples_ddim
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 46da0a97..3ff4b634 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,17 +1,16 @@
-from collections import deque
import torch
import inspect
import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
+from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
-from modules.processing import StableDiffusionProcessing
-from modules.shared import opts, state
+from modules.shared import opts
import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
@@ -28,10 +27,6 @@ samplers_k_diffusion = [
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
- ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
- ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
- ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
@@ -57,342 +52,17 @@ k_diffusion_scheduler = {
}
-def catenate_conds(conds):
- if not isinstance(conds[0], dict):
- return torch.cat(conds)
-
- return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
-
-
-def subscript_cond(cond, a, b):
- if not isinstance(cond, dict):
- return cond[a:b]
-
- return {key: vec[a:b] for key, vec in cond.items()}
-
-
-def pad_cond(tensor, repeats, empty):
- if not isinstance(tensor, dict):
- return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
-
- tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
- return tensor
-
-
-class CFGDenoiser(torch.nn.Module):
- """
- Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
- that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
- instead of one. Originally, the second prompt is just an empty string, but we use non-empty
- negative prompt.
- """
-
- def __init__(self, sampler):
- super().__init__()
- self.sampler = sampler
- self.model_wrap = None
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.steps = None
- self.step = 0
- self.image_cfg_scale = None
- self.padded_cond_uncond = False
- self.p = None
-
- @property
- def inner_model(self):
- if self.model_wrap is None:
- denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
- self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
-
- return self.model_wrap
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
- return denoised
-
- def combine_denoised_for_edit_model(self, x_out, cond_scale):
- out_cond, out_img_cond, out_uncond = x_out.chunk(3)
- denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
-
- return denoised
-
- def update_inner_model(self):
- self.model_wrap = None
-
- c, uc = self.p.get_conds()
- self.sampler.sampler_extra_args['cond'] = c
- self.sampler.sampler_extra_args['uncond'] = uc
-
- def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
- if sd_samplers_common.apply_refiner(self):
- cond = self.sampler.sampler_extra_args['cond']
- uncond = self.sampler.sampler_extra_args['uncond']
-
- # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
- # so is_edit_model is set to False to support AND composition.
- is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
- assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
-
- if shared.sd_model.model.conditioning_key == "crossattn-adm":
- image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
- else:
- image_uncond = image_cond
- if isinstance(uncond, dict):
- make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
- else:
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
-
- if not is_edit_model:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
- else:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
-
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
- tensor = denoiser_params.text_cond
- uncond = denoiser_params.text_uncond
- skip_uncond = False
-
- # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
- if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
- skip_uncond = True
- x_in = x_in[:-batch_size]
- sigma_in = sigma_in[:-batch_size]
-
- self.padded_cond_uncond = False
- if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
- empty = shared.sd_model.cond_stage_model_empty_prompt
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
- if num_repeats < 0:
- tensor = pad_cond(tensor, -num_repeats, empty)
- self.padded_cond_uncond = True
- elif num_repeats > 0:
- uncond = pad_cond(uncond, num_repeats, empty)
- self.padded_cond_uncond = True
-
- if tensor.shape[1] == uncond.shape[1] or skip_uncond:
- if is_edit_model:
- cond_in = catenate_conds([tensor, uncond, uncond])
- elif skip_uncond:
- cond_in = tensor
- else:
- cond_in = catenate_conds([tensor, uncond])
-
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
-
- if not is_edit_model:
- c_crossattn = subscript_cond(tensor, a, b)
- else:
- c_crossattn = torch.cat([tensor[a:b]], uncond)
-
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
-
- if not skip_uncond:
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
-
- denoised_image_indexes = [x[0][0] for x in conds_list]
- if skip_uncond:
- fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
- x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
-
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
- cfg_denoised_callback(denoised_params)
-
- devices.test_for_nans(x_out, "unet")
-
- if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
- elif opts.live_preview_content == "Negative prompt":
- sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
- if is_edit_model:
- denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
- elif skip_uncond:
- denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
- else:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
-
- after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
- cfg_after_cfg_callback(after_cfg_callback_params)
- denoised = after_cfg_callback_params.x
-
- self.step += 1
- return denoised
-
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- return devices.randn_like(x)
+class KDiffusionSampler(sd_samplers_common.Sampler):
+ def __init__(self, funcname, sd_model):
+ super().__init__(funcname)
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
- self.p = None
- self.funcname = funcname
- self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
- self.sampler_extra_args = {}
- self.model_wrap_cfg = CFGDenoiser(self)
- self.model_wrap = self.model_wrap_cfg.inner_model
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.config = None # set by the function calling the constructor
- self.last_latent = None
- self.s_min_uncond = None
-
- # NOTE: These are also defined in the StableDiffusionProcessing class.
- # They should have been here to begin with but we're going to
- # leave that class __init__ signature alone.
- self.s_churn = 0.0
- self.s_tmin = 0.0
- self.s_tmax = float('inf')
- self.s_noise = 1.0
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if opts.live_preview_content == "Combined":
- sd_samplers_common.store_latent(latent)
- self.last_latent = latent
-
- if self.stop_at is not None and step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- state.sampling_step = step
- shared.total_tqdm.update()
-
- def launch_sampling(self, steps, func):
- self.model_wrap_cfg.steps = steps
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except RecursionError:
- print(
- 'Encountered RecursionError during sampling, returning last latent. '
- 'rho >5 with a polyexponential scheduler may cause this error. '
- 'You should try to use a smaller rho value instead.'
- )
- return self.last_latent
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p: StableDiffusionProcessing):
- self.p = p
- self.model_wrap_cfg.p = p
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.step = 0
- self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
- self.eta = p.eta if p.eta is not None else opts.eta_ancestral
- self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
-
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
-
- if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
- p.extra_generation_params["Eta"] = self.eta
-
- extra_params_kwargs['eta'] = self.eta
-
- if len(self.extra_params) > 0:
- s_churn = getattr(opts, 's_churn', p.s_churn)
- s_tmin = getattr(opts, 's_tmin', p.s_tmin)
- s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
- s_noise = getattr(opts, 's_noise', p.s_noise)
-
- if s_churn != self.s_churn:
- extra_params_kwargs['s_churn'] = s_churn
- p.s_churn = s_churn
- p.extra_generation_params['Sigma churn'] = s_churn
- if s_tmin != self.s_tmin:
- extra_params_kwargs['s_tmin'] = s_tmin
- p.s_tmin = s_tmin
- p.extra_generation_params['Sigma tmin'] = s_tmin
- if s_tmax != self.s_tmax:
- extra_params_kwargs['s_tmax'] = s_tmax
- p.s_tmax = s_tmax
- p.extra_generation_params['Sigma tmax'] = s_tmax
- if s_noise != self.s_noise:
- extra_params_kwargs['s_noise'] = s_noise
- p.s_noise = s_noise
- p.extra_generation_params['Sigma noise'] = s_noise
-
- return extra_params_kwargs
+ self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
+
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
+ self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -444,22 +114,12 @@ class KDiffusionSampler:
return sigmas
- def create_noise_sampler(self, x, sigmas, p):
- """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
- if shared.opts.no_dpmpp_sde_batch_determinism:
- return None
-
- from k_diffusion.sampling import BrownianTreeNoiseSampler
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
- current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
- return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
-
sigma_sched = sigmas[steps - t_enc - 1:]
+
xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p)
@@ -508,12 +168,14 @@ class KDiffusionSampler:
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = steps
+
if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in parameters:
- extra_params_kwargs['n'] = steps
- else:
+
+ if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigmas
if self.config.options.get('brownian_noise', False):
@@ -535,3 +197,4 @@ class KDiffusionSampler:
return samples
+
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
new file mode 100644
index 00000000..d89d0efb
--- /dev/null
+++ b/modules/sd_samplers_timesteps.py
@@ -0,0 +1,147 @@
+import torch
+import inspect
+from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+
+from modules.shared import opts
+import modules.shared as shared
+
+samplers_timesteps = [
+ ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+ ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
+ ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
+]
+
+
+samplers_data_timesteps = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_timesteps
+]
+
+
+class CompVisTimestepsDenoiser(torch.nn.Module):
+ def __init__(self, model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.inner_model = model
+
+ def forward(self, input, timesteps, **kwargs):
+ return self.inner_model.apply_model(input, timesteps, **kwargs)
+
+
+class CompVisTimestepsVDenoiser(torch.nn.Module):
+ def __init__(self, model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.inner_model = model
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+
+ def forward(self, input, timesteps, **kwargs):
+ model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
+ e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
+ return e_t
+
+
+class CFGDenoiserTimesteps(CFGDenoiser):
+
+ def __init__(self, model, sampler):
+ super().__init__(model, sampler)
+
+ self.alphas = model.inner_model.alphas_cumprod
+
+ def get_pred_x0(self, x_in, x_out, sigma):
+ ts = int(sigma.item())
+
+ s_in = x_in.new_ones([x_in.shape[0]])
+ a_t = self.alphas[ts].item() * s_in
+ sqrt_one_minus_at = (1 - a_t).sqrt()
+
+ pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
+
+ return pred_x0
+
+
+class CompVisSampler(sd_samplers_common.Sampler):
+ def __init__(self, funcname, sd_model):
+ super().__init__(funcname)
+
+ self.eta_option_field = 'eta_ddim'
+ self.eta_infotext_field = 'Eta DDIM'
+
+ denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+ self.model_wrap = denoiser(sd_model)
+ self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
+
+ def get_timesteps(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
+
+ timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
+
+ return timesteps
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+ timesteps = self.get_timesteps(p, steps)
+ timesteps_sched = timesteps[:t_enc]
+
+ alphas_cumprod = shared.sd_model.alphas_cumprod
+ sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
+ sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
+
+ xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'timesteps' in parameters:
+ extra_params_kwargs['timesteps'] = timesteps_sched
+ if 'is_img2img' in parameters:
+ extra_params_kwargs['is_img2img'] = True
+
+ self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
+ extra_args = {
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps = steps or p.steps
+ timesteps = self.get_timesteps(p, steps)
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'timesteps' in parameters:
+ extra_params_kwargs['timesteps'] = timesteps
+
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ return samples
+
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
new file mode 100644
index 00000000..48d7e649
--- /dev/null
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -0,0 +1,135 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+import numpy as np
+
+from modules import shared
+from modules.models.diffusion.uni_pc import uni_pc
+
+
+@torch.no_grad()
+def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+
+ e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+
+ a_t = alphas[index].item() * s_in
+ a_prev = alphas_prev[index].item() * s_in
+ sigma_t = sigmas[index].item() * s_in
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+ noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+@torch.no_grad()
+def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ old_eps = []
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = alphas[index].item() * s_in
+ a_prev = alphas_prev[index].item() * s_in
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev).sqrt() * e_t
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev, pred_x0
+
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+ ts = timesteps[index].item() * s_in
+ t_next = timesteps[max(index - 1, 0)].item() * s_in
+
+ e_t = model(x, ts, **extra_args)
+
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = model(x_prev, t_next, **extra_args)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ else:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+
+ x = x_prev
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+class UniPCCFG(uni_pc.UniPC):
+ def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
+ super().__init__(None, *args, **kwargs)
+
+ def after_update(x, model_x):
+ callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
+ self.index += 1
+
+ self.cfg_model = cfg_model
+ self.extra_args = extra_args
+ self.callback = callback
+ self.index = 0
+ self.after_update = after_update
+
+ def get_model_input_time(self, t_continuous):
+ return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
+
+ def model(self, x, t):
+ t_input = self.get_model_input_time(t)
+
+ res = self.cfg_model(x, t_input, **self.extra_args)
+
+ return res
+
+
+def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+
+ ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
+ unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
+ x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0bd5e19b..38bcb840 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,5 +1,7 @@
import os
import collections
+from dataclasses import dataclass
+
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
import glob
from copy import deepcopy
@@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
return None
-def resolve_vae(checkpoint_file):
- if shared.cmd_opts.vae_path is not None:
- return shared.cmd_opts.vae_path, 'from commandline argument'
+@dataclass
+class VaeResolution:
+ vae: str = None
+ source: str = None
+ resolved: bool = True
+
+ def tuple(self):
+ return self.vae, self.source
+
+
+def is_automatic():
+ return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+
+
+def resolve_vae_from_setting() -> VaeResolution:
+ if shared.opts.sd_vae == "None":
+ return VaeResolution()
+
+ vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+ if vae_from_options is not None:
+ return VaeResolution(vae_from_options, 'specified in settings')
+
+ if not is_automatic():
+ print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+ return VaeResolution(resolved=False)
+
+
+def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
metadata = extra_networks.get_user_metadata(checkpoint_file)
vae_metadata = metadata.get("vae", None)
if vae_metadata is not None and vae_metadata != "Automatic":
if vae_metadata == "None":
- return None, None
+ return VaeResolution()
vae_from_metadata = vae_dict.get(vae_metadata, None)
if vae_from_metadata is not None:
- return vae_from_metadata, "from user metadata"
+ return VaeResolution(vae_from_metadata, "from user metadata")
+
+ return VaeResolution(resolved=False)
- is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
- return vae_near_checkpoint, 'found near the checkpoint'
+ return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
- if shared.opts.sd_vae == "None":
- return None, None
+ return VaeResolution(resolved=False)
- vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
- if vae_from_options is not None:
- return vae_from_options, 'specified in settings'
- if not is_automatic:
- print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+def resolve_vae(checkpoint_file) -> VaeResolution:
+ if shared.cmd_opts.vae_path is not None:
+ return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
+
+ if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
+ return resolve_vae_from_setting()
+
+ res = resolve_vae_from_user_metadata(checkpoint_file)
+ if res.resolved:
+ return res
+
+ res = resolve_vae_near_checkpoint(checkpoint_file)
+ if res.resolved:
+ return res
+
+ res = resolve_vae_from_setting()
- return None, None
+ return res
def load_vae_dict(filename, map_location):
@@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
checkpoint_file = checkpoint_info.filename
if vae_file == unspecified:
- vae_file, vae_source = resolve_vae(checkpoint_file)
+ vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
else:
vae_source = "from function argument"
diff --git a/modules/shared.py b/modules/shared.py
index ed8395dc..2fd29904 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -422,6 +422,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
}))
options_templates.update(options_section(('system', "System"), {
+ "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
@@ -481,7 +482,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
"""),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
- "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
@@ -610,14 +611,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
- 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}),
- 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf"),
- 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
+ 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
+ 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
- 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
- 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
@@ -735,6 +736,10 @@ class Options:
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
+ # 1.6.0 VAE defaults
+ if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+ self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 935ed418..8fa389b5 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,7 +1,7 @@
from contextlib import closing
import modules.scripts
-from modules import sd_samplers, processing
+from modules import processing
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_name=sd_samplers.samplers[sampler_index].name,
+ sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
@@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
- hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
+ hr_sampler_name=hr_sampler_name,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
diff --git a/modules/ui.py b/modules/ui.py
index 1af6b4c8..e3753e97 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -29,7 +29,6 @@ import modules.shared as shared
import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
from modules.generation_parameters_copypaste import image_from_url_text
create_setting_component = ui_settings.create_setting_component
@@ -41,6 +40,9 @@ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
+# Likewise, add explicit content-type header for certain missing image types
+mimetypes.add_type('image/webp', '.webp')
+
if not cmd_opts.share and not cmd_opts.listen:
# fix gradio phoning home
gradio.utils.version_check = lambda: None
@@ -357,14 +359,14 @@ def create_output_panel(tabname, outdir):
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
- return steps, sampler_index
+ return steps, sampler_name
def ordered_ui_categories():
@@ -405,13 +407,13 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
extra_tabs.__enter__()
- with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
+ with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+ steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
elif category == "dimensions":
with FormRow():
@@ -457,7 +459,7 @@ def create_ui():
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
- hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
+ hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
@@ -517,7 +519,7 @@ def create_ui():
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
- sampler_index,
+ sampler_name,
restore_faces,
tiling,
batch_count,
@@ -535,7 +537,7 @@ def create_ui():
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
- hr_sampler_index,
+ hr_sampler_name,
hr_prompt,
hr_negative_prompt,
override_settings,
@@ -580,7 +582,7 @@ def create_ui():
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
- (sampler_index, "Sampler"),
+ (sampler_name, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(seed, "Seed"),
@@ -602,7 +604,7 @@ def create_ui():
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"),
- (hr_sampler_index, "Hires sampler"),
+ (hr_sampler_name, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
@@ -618,7 +620,7 @@ def create_ui():
toprow.prompt,
toprow.negative_prompt,
steps,
- sampler_index,
+ sampler_name,
cfg_scale,
seed,
width,
@@ -741,7 +743,7 @@ def create_ui():
for category in ordered_ui_categories():
if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
+ steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
elif category == "dimensions":
with FormRow():
@@ -873,7 +875,7 @@ def create_ui():
init_img_inpaint,
init_mask_inpaint,
steps,
- sampler_index,
+ sampler_name,
mask_blur,
mask_alpha,
inpainting_fill,
@@ -969,7 +971,7 @@ def create_ui():
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
- (sampler_index, "Sampler"),
+ (sampler_name, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"),
diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py
index 2c69aab8..25df0a80 100644
--- a/modules/ui_extra_networks_checkpoints_user_metadata.py
+++ b/modules/ui_extra_networks_checkpoints_user_metadata.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import ui_extra_networks_user_metadata, sd_vae
+from modules import ui_extra_networks_user_metadata, sd_vae, shared
from modules.ui_common import create_refresh_button
@@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
self.write_user_metadata(name, user_metadata)
+ def update_vae(self, name):
+ if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
+ sd_vae.reload_vae_weights()
+
def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)
@@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
]
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
+ self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
+