aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py60
1 files changed, 33 insertions, 27 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 0fc9f456..59982fc9 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -9,25 +8,28 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True, 'discard_next_to_last_sigma': True}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True, 'discard_next_to_last_sigma': True}),
]
samplers_data_k_diffusion = [
@@ -87,17 +89,17 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -161,7 +163,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,6 +183,10 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
@@ -224,7 +230,7 @@ class KDiffusionSampler:
self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None
+ self.config = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
@@ -317,7 +323,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
-
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -333,16 +339,16 @@ class KDiffusionSampler:
if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigma_sched
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ extra_args = {
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
@@ -369,15 +375,15 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))