From f1975b0213f5be400889ec04b3891d1cb571fe20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:01:07 +0300 Subject: initial refiner support --- modules/processing.py | 4 ++++ modules/sd_models.py | 18 +++++++++++++++++- modules/sd_samplers_common.py | 19 ++++++++++++++++++- modules/sd_samplers_compvis.py | 12 +++++++++++- modules/sd_samplers_kdiffusion.py | 30 ++++++++++++++++++++++++------ modules/shared.py | 2 ++ 6 files changed, 76 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 31745006..f4748d6d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -666,6 +666,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: + # after running refiner, the refiner model is not unloaded - webui swaps back to main model here + if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint: + sd_models.reload_model_weights() + # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) diff --git a/modules/sd_models.py b/modules/sd_models.py index f6051604..981aa93d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return res +class SkipWritingToConfig: + """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.""" + + skip = False + previous = None + + def __enter__(self): + self.previous = SkipWritingToConfig.skip + SkipWritingToConfig.skip = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + SkipWritingToConfig.skip = self.previous + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 39586b40..3f3e83e3 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models from modules.shared import opts, state SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -127,3 +127,20 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() + + +def apply_refiner(sampler): + completed_ratio = sampler.step / sampler.steps + if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint: + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + + sampler.update_inner_model() + + sampler.p.setup_conds() diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 4a8396f9..5df926d3 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -19,7 +19,8 @@ samplers_data_compvis = [ class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) + self.p = None + self.sampler = constructor(shared.sd_model) self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') self.is_plms = hasattr(self.sampler, 'p_sample_plms') self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) @@ -32,6 +33,7 @@ class VanillaStableDiffusionSampler: self.nmask = None self.init_latent = None self.sampler_noises = None + self.steps = None self.step = 0 self.stop_at = None self.eta = None @@ -44,6 +46,7 @@ class VanillaStableDiffusionSampler: return 0 def launch_sampling(self, steps, func): + self.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -61,10 +64,15 @@ class VanillaStableDiffusionSampler: return res + def update_inner_model(self): + self.sampler.model = shared.sd_model + def before_sample(self, x, ts, cond, unconditional_conditioning): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + sd_samplers_common.apply_refiner(self) + if self.stop_at is not None and self.step > self.stop_at: raise sd_samplers_common.InterruptedException @@ -134,6 +142,8 @@ class VanillaStableDiffusionSampler: self.update_step(x) def initialize(self, p): + self.p = p + if self.is_ddim: self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim else: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index db71a549..be1bd35e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models from modules.processing import StableDiffusionProcessing from modules.shared import opts, state @@ -87,15 +87,25 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model): + def __init__(self): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False + self.p = None + + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -113,10 +123,15 @@ class CFGDenoiser(torch.nn.Module): return denoised + def update_inner_model(self): + self.model_wrap = None + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + sd_samplers_common.apply_refiner(self) + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 @@ -267,13 +282,13 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser() + self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None self.eta = None @@ -305,6 +320,7 @@ class KDiffusionSampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -324,6 +340,8 @@ class KDiffusionSampler: return p.steps def initialize(self, p: StableDiffusionProcessing): + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 diff --git a/modules/shared.py b/modules/shared.py index 078e8135..ed8395dc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -461,6 +461,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), + "sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"), + "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { -- cgit v1.2.1 From 956e69bf3a9387b6a6cda6823d729e3d3f13c3e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:07:08 +0300 Subject: lint! --- modules/sd_samplers_kdiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index be1bd35e..3aee4e3a 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra from modules.processing import StableDiffusionProcessing from modules.shared import opts, state -- cgit v1.2.1 From 5a0db84b6c7322082c7532df11a29a95a59a612b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:53:33 +0300 Subject: add infotext add proper support for recalculating conds in k-diffusion samplers remove support for compvis samplers --- modules/generation_parameters_copypaste.py | 2 ++ modules/processing.py | 10 ++++++++++ modules/sd_samplers_common.py | 29 ++++++++++++++++++++--------- modules/sd_samplers_compvis.py | 2 -- modules/sd_samplers_kdiffusion.py | 24 ++++++++++++++++-------- 5 files changed, 48 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index e71c9601..6711ca16 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [ ('Pad conds', 'pad_cond_uncond'), ('VAE Encoder', 'sd_vae_encode_method'), ('VAE Decoder', 'sd_vae_decode_method'), + ('Refiner', 'sd_refiner_checkpoint'), + ('Refiner switch at', 'sd_refiner_switch_at'), ] diff --git a/modules/processing.py b/modules/processing.py index f4748d6d..ec66fd8e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -370,6 +370,9 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + def get_conds(self): + return self.c, self.uc + def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -1251,6 +1254,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.extra_network_data) + def get_conds(self): + if self.is_hr_pass: + return self.hr_c, self.hr_uc + + return super().get_conds() + + def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 3f3e83e3..92bf0ca1 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -131,16 +131,27 @@ replace_torchsde_browinan() def apply_refiner(sampler): completed_ratio = sampler.step / sampler.steps - if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint: - refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) - if refiner_checkpoint_info is None: - raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') - with sd_models.SkipWritingToConfig(): - sd_models.reload_model_weights(info=refiner_checkpoint_info) + if completed_ratio <= shared.opts.sd_refiner_switch_at: + return False + + if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: + return False + + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') + + sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + sampler.p.setup_conds() + sampler.update_inner_model() - devices.torch_gc() + return True - sampler.update_inner_model() - sampler.p.setup_conds() diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 5df926d3..2eeec18a 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -71,8 +71,6 @@ class VanillaStableDiffusionSampler: if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) - if self.stop_at is not None and self.step > self.stop_at: raise sd_samplers_common.InterruptedException diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3aee4e3a..46da0a97 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self): + def __init__(self, sampler): super().__init__() + self.sampler = sampler self.model_wrap = None self.mask = None self.nmask = None @@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module): def update_inner_model(self): self.model_wrap = None + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. @@ -282,12 +289,12 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser() + self.sampler_extra_args = {} + self.model_wrap_cfg = CFGDenoiser(self) self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None @@ -476,7 +483,7 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -484,7 +491,7 @@ class KDiffusionSampler: 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -514,13 +521,14 @@ class KDiffusionSampler: extra_params_kwargs['noise_sampler'] = noise_sampler self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True -- cgit v1.2.1 From f8ff8c0638997fd0aef217db1505598846f14782 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:09:40 +0300 Subject: merge errors --- modules/sd_samplers_cfg_denoiser.py | 23 +++++++++++++++++++++-- modules/sd_samplers_common.py | 6 +++++- modules/sd_samplers_kdiffusion.py | 17 ++++++++++++----- modules/sd_samplers_timesteps.py | 27 +++++++++++++++++---------- 4 files changed, 55 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index d826222c..a532e013 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model, sampler): + def __init__(self, sampler): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False self.sampler = sampler + self.model_wrap = None + self.p = None + + @property + def inner_model(self): + raise NotImplementedError() + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -68,10 +76,21 @@ class CFGDenoiser(torch.nn.Module): def get_pred_x0(self, x_in, x_out, sigma): return x_out + def update_inner_model(self): + self.model_wrap = None + + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 15f27970..fa3614ff 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -202,8 +202,9 @@ class Sampler: self.conditioning_key = shared.sd_model.model.conditioning_key - self.model_wrap = None + self.p = None self.model_wrap_cfg = None + self.sampler_extra_args = None def callback_state(self, d): step = d['i'] @@ -215,6 +216,7 @@ class Sampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -234,6 +236,8 @@ class Sampler: return p.steps def initialize(self, p) -> dict: + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3ff4b634..95a43cef 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -52,17 +52,24 @@ k_diffusion_scheduler = { } +class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap + + class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - super().__init__(funcname) - self.extra_params = sampler_extra_params.get(funcname, []) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserKDiffusion(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index d89d0efb..965e61c6 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -44,10 +44,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): class CFGDenoiserTimesteps(CFGDenoiser): - def __init__(self, model, sampler): - super().__init__(model, sampler) + def __init__(self, sampler): + super().__init__(sampler) - self.alphas = model.inner_model.alphas_cumprod + self.alphas = shared.sd_model.alphas_cumprod def get_pred_x0(self, x_in, x_out, sigma): ts = int(sigma.item()) @@ -60,6 +60,14 @@ class CFGDenoiserTimesteps(CFGDenoiser): return pred_x0 + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + class CompVisSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): @@ -68,9 +76,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_option_field = 'eta_ddim' self.eta_infotext_field = 'Eta DDIM' - denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser - self.model_wrap = denoiser(sd_model) - self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserTimesteps(self) def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -106,7 +112,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -114,7 +120,7 @@ class CompVisSampler(sd_samplers_common.Sampler): 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -132,13 +138,14 @@ class CompVisSampler(sd_samplers_common.Sampler): extra_params_kwargs['timesteps'] = timesteps self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True -- cgit v1.2.1 From ec194b637476855ea5918a44a65e85fb587483ab Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:14:02 +0300 Subject: fix webui not switching back to original model from refiner when batch count is greater than 1 --- modules/processing.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index b635cc74..cf62cdd3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -766,6 +766,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break + sd_models.reload_model_weights() # model can be changed for example by refiner + p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] -- cgit v1.2.1 From 1aefb5025929818b2a96cbb6148fcc2db7b947ec Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:17:25 +0300 Subject: add None refiner option --- modules/sd_samplers_common.py | 3 +++ modules/shared.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index fa3614ff..b6ad6830 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -137,6 +137,9 @@ def apply_refiner(sampler): if completed_ratio <= shared.opts.sd_refiner_switch_at: return False + if shared.opts.sd_refiner_checkpoint == "None": + return False + if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: return False diff --git a/modules/shared.py b/modules/shared.py index 2fd29904..9935d2a7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -462,7 +462,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), - "sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"), + "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"), "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) -- cgit v1.2.1 From ac8a5d18d3ede6bcb8fa5a3da1c7c28e064cd65d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 10 Aug 2023 17:04:59 +0300 Subject: resolve merge issues --- modules/sd_models.py | 7 +++++-- modules/sd_samplers_kdiffusion.py | 3 +-- modules/shared_options.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6cb2f34..a178adca 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -640,8 +640,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): timer.record("send model to device") model_data.set_sd_model(already_loaded) - shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title - shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title + shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e1854980..95a43cef 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra -from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared diff --git a/modules/shared_options.py b/modules/shared_options.py index 9ae51f18..1e5b64ea 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -140,6 +140,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), + "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"), + "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { -- cgit v1.2.1