aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-10 17:05:32 +0300
committerGitHub <noreply@github.com>2023-08-10 17:05:32 +0300
commit36762f0eaf04c270dde23849cb198446ecdc4100 (patch)
tree879b63e94d986f8d4fb30d65ee5aa4ae45f3e640 /modules
parent959404e0e29531d24f2e02088bf0399f4b9db15b (diff)
parentac8a5d18d3ede6bcb8fa5a3da1c7c28e064cd65d (diff)
Merge pull request #12371 from AUTOMATIC1111/refiner
initial refiner support
Diffstat (limited to 'modules')
-rw-r--r--modules/processing.py16
-rw-r--r--modules/sd_models.py25
-rw-r--r--modules/sd_samplers_cfg_denoiser.py23
-rw-r--r--modules/sd_samplers_common.py37
-rw-r--r--modules/sd_samplers_compvis.py0
-rw-r--r--modules/sd_samplers_kdiffusion.py29
-rw-r--r--modules/sd_samplers_timesteps.py27
-rw-r--r--modules/shared_options.py2
8 files changed, 131 insertions, 28 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 44d47e8c..efa6eafa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -377,6 +377,9 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ def get_conds(self):
+ return self.c, self.uc
+
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -611,6 +614,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
+ # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
+ if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+ sd_models.reload_model_weights()
+
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
@@ -710,6 +717,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
+ sd_models.reload_model_weights() # model can be changed for example by refiner
+
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -1201,6 +1210,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.extra_network_data)
+ def get_conds(self):
+ if self.is_hr_pass:
+ return self.hr_c, self.hr_uc
+
+ return super().get_conds()
+
+
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7a866a07..a178adca 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -295,11 +295,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return res
+class SkipWritingToConfig:
+ """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
+
+ skip = False
+ previous = None
+
+ def __enter__(self):
+ self.previous = SkipWritingToConfig.skip
+ SkipWritingToConfig.skip = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ SkipWritingToConfig.skip = self.previous
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -624,8 +640,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
timer.record("send model to device")
model_data.set_sd_model(already_loaded)
- shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
- shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
+
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
+ shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
+
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index d826222c..a532e013 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""
- def __init__(self, model, sampler):
+ def __init__(self, sampler):
super().__init__()
- self.inner_model = model
+ self.model_wrap = None
self.mask = None
self.nmask = None
self.init_latent = None
+ self.steps = None
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.sampler = sampler
+ self.model_wrap = None
+ self.p = None
+
+ @property
+ def inner_model(self):
+ raise NotImplementedError()
+
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -68,10 +76,21 @@ class CFGDenoiser(torch.nn.Module):
def get_pred_x0(self, x_in, x_out, sigma):
return x_out
+ def update_inner_model(self):
+ self.model_wrap = None
+
+ c, uc = self.p.get_conds()
+ self.sampler.sampler_extra_args['cond'] = c
+ self.sampler.sampler_extra_args['uncond'] = uc
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
+ if sd_samplers_common.apply_refiner(self):
+ cond = self.sampler.sampler_extra_args['cond']
+ uncond = self.sampler.sampler_extra_args['uncond']
+
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 97bc0804..35c4d657 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -3,7 +3,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling
@@ -131,6 +131,35 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
+def apply_refiner(sampler):
+ completed_ratio = sampler.step / sampler.steps
+
+ if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ return False
+
+ if shared.opts.sd_refiner_checkpoint == "None":
+ return False
+
+ if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
+ return False
+
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+ sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+ devices.torch_gc()
+ sampler.p.setup_conds()
+ sampler.update_inner_model()
+
+ return True
+
+
class TorchHijack:
"""This is here to replace torch.randn_like of k-diffusion.
@@ -176,8 +205,9 @@ class Sampler:
self.conditioning_key = shared.sd_model.model.conditioning_key
- self.model_wrap = None
+ self.p = None
self.model_wrap_cfg = None
+ self.sampler_extra_args = None
def callback_state(self, d):
step = d['i']
@@ -189,6 +219,7 @@ class Sampler:
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
+ self.model_wrap_cfg.steps = steps
state.sampling_steps = steps
state.sampling_step = 0
@@ -208,6 +239,8 @@ class Sampler:
return p.steps
def initialize(self, p) -> dict:
+ self.p = p
+ self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/modules/sd_samplers_compvis.py
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 5613b8c1..95a43cef 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,8 +1,7 @@
import torch
import inspect
import k_diffusion.sampling
-from modules import sd_samplers_common, sd_samplers_extra
-from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules.shared import opts
import modules.shared as shared
@@ -53,17 +52,24 @@ k_diffusion_scheduler = {
}
+class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
+ @property
+ def inner_model(self):
+ if self.model_wrap is None:
+ denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+ self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+
+ return self.model_wrap
+
+
class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):
-
super().__init__(funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
- denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
- self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.model_wrap_cfg = CFGDenoiser(self.model_wrap, self)
+ self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -144,7 +150,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args = {
+ self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
@@ -152,7 +158,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
's_min_uncond': self.s_min_uncond
}
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
@@ -184,13 +190,14 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
extra_params_kwargs['noise_sampler'] = noise_sampler
self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ }
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index f61799a8..16572c7e 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -45,10 +45,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
class CFGDenoiserTimesteps(CFGDenoiser):
- def __init__(self, model, sampler):
- super().__init__(model, sampler)
+ def __init__(self, sampler):
+ super().__init__(sampler)
- self.alphas = model.inner_model.alphas_cumprod
+ self.alphas = shared.sd_model.alphas_cumprod
def get_pred_x0(self, x_in, x_out, sigma):
ts = int(sigma.item())
@@ -61,6 +61,14 @@ class CFGDenoiserTimesteps(CFGDenoiser):
return pred_x0
+ @property
+ def inner_model(self):
+ if self.model_wrap is None:
+ denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+ self.model_wrap = denoiser(shared.sd_model)
+
+ return self.model_wrap
+
class CompVisSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):
@@ -69,9 +77,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_option_field = 'eta_ddim'
self.eta_infotext_field = 'Eta DDIM'
- denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
- self.model_wrap = denoiser(sd_model)
- self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
+ self.model_wrap_cfg = CFGDenoiserTimesteps(self)
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -107,7 +113,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args = {
+ self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
@@ -115,7 +121,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
's_min_uncond': self.s_min_uncond
}
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
@@ -133,13 +139,14 @@ class CompVisSampler(sd_samplers_common.Sampler):
extra_params_kwargs['timesteps'] = timesteps
self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ }
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 9ae51f18..1e5b64ea 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -140,6 +140,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
+ "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
+ "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {