aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r--modules/sd_samplers.py56
1 files changed, 32 insertions, 24 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4fe67854..2ca17d8b 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,4 +1,4 @@
-from collections import namedtuple
+from collections import namedtuple, deque
import numpy as np
from math import floor
import torch
@@ -26,6 +26,7 @@ samplers_k_diffusion = [
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@@ -33,6 +34,7 @@ samplers_k_diffusion = [
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
@@ -50,6 +52,7 @@ all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
+samplers_map = {}
def create_sampler(name, model):
@@ -75,6 +78,12 @@ def set_samplers():
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+ samplers_map.clear()
+ for sampler in all_samplers:
+ samplers_map[sampler.name.lower()] = sampler.name
+ for alias in sampler.aliases:
+ samplers_map[alias.lower()] = sampler.name
+
set_samplers()
@@ -127,7 +136,8 @@ class InterruptedException(BaseException):
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
@@ -218,7 +228,6 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
@@ -227,7 +236,6 @@ class VanillaStableDiffusionSampler:
return num_steps
-
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
@@ -260,9 +268,10 @@ class VanillaStableDiffusionSampler:
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@@ -335,28 +344,39 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack:
- def __init__(self, kdiff_sampler):
- self.kdiff_sampler = kdiff_sampler
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item):
if item == 'randn_like':
- return self.kdiff_sampler.randn_like
+ return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ return torch.randn_like(x)
+
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
- self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
- self.sampler_noise_index = 0
self.stop_at = None
self.eta = None
self.default_eta = 1.0
@@ -389,26 +409,14 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p):
return p.steps
- def randn_like(self, x):
- noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
- if noise is not None and x.shape == noise.shape:
- res = noise
- else:
- res = torch.randn_like(x)
-
- self.sampler_noise_index += 1
- return res
-
def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0
- self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
if self.sampler_noises is not None:
- k_diffusion.sampling.torch = TorchHijack(self)
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
extra_params_kwargs = {}
for param_name in self.extra_params: