aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-05-17 23:18:56 +0300
committerGitHub <noreply@github.com>2023-05-17 23:18:56 +0300
commit04b4508a66de58c9f3a422fdcad4dd2ec3ad39ce (patch)
tree1cb30a63099a69f678d4901b495203b765a6df59 /modules/sd_samplers_kdiffusion.py
parent7201d940a4fe664beb9662fadbeade4ee1d788f7 (diff)
parentb397f63e00bbfbe9087d80abb457aa9a593b181b (diff)
Merge branch 'dev' into improve-frontend-responsiveness
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py98
1 files changed, 61 insertions, 37 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..552c6c64 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -9,25 +8,26 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
]
samplers_data_k_diffusion = [
@@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
return denoised
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -115,12 +115,21 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
+ skip_uncond = False
- if tensor.shape[1] == uncond.shape[1]:
- if not is_edit_model:
- cond_in = torch.cat([tensor, uncond])
- else:
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
@@ -144,28 +153,39 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
+ sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- if not is_edit_model:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- else:
+ if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
- self.step += 1
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+ self.step += 1
return denoised
@@ -182,7 +202,7 @@ class TorchHijack:
if hasattr(torch, item):
return getattr(torch, item)
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
if self.sampler_noises:
@@ -190,7 +210,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
- if x.device.type == 'mps':
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
@@ -210,6 +230,7 @@ class KDiffusionSampler:
self.eta = None
self.config = None
self.last_latent = None
+ self.s_min_uncond = None
self.conditioning_key = sd_model.model.conditioning_key
@@ -244,6 +265,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -299,7 +321,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
-
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -322,10 +344,11 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -356,10 +379,11 @@ class KDiffusionSampler:
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples