aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-08 19:20:11 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-08 21:04:44 +0300
commit8285a149d8c488ae6c7a566eb85fb5e825145464 (patch)
tree89f204c4dc44f58ac2f15719f9bdbf28c4590cb5
parent2d8e4a654480ea080fec62834331a3c632ed0330 (diff)
add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them)
-rw-r--r--modules/sd_samplers.py3
-rw-r--r--modules/sd_samplers_cfg_denoiser.py50
-rw-r--r--modules/sd_samplers_common.py140
-rw-r--r--modules/sd_samplers_kdiffusion.py152
-rw-r--r--modules/sd_samplers_timesteps.py147
-rw-r--r--modules/sd_samplers_timesteps_impl.py135
6 files changed, 455 insertions, 172 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index bea2684c..fe206894 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,4 +1,4 @@
-from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
+from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
@@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_compvis.samplers_data_compvis,
+ *sd_samplers_timesteps.samplers_data_timesteps,
]
all_samplers_map = {x.name: x for x in all_samplers}
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 33a49783..166a00c7 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""
- def __init__(self, model):
+ def __init__(self, model, sampler):
super().__init__()
self.inner_model = model
self.mask = None
@@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
+ self.sampler = sampler
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module):
return denoised
+ def get_pred_x0(self, x_in, x_out, sigma):
+ return x_out
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ if self.mask is not None:
+ x = self.init_latent * self.mask + self.nmask * x
+
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module):
devices.test_for_nans(x_out, "unet")
- if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
- elif opts.live_preview_content == "Negative prompt":
- sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
@@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+ if opts.live_preview_content == "Prompt":
+ preview = self.sampler.last_latent
+ elif opts.live_preview_content == "Negative prompt":
+ preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+ else:
+ preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+ sd_samplers_common.store_latent(preview)
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
@@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module):
self.step += 1
return denoised
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- return devices.randn_like(x)
-
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 39586b40..adda963b 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,9 +1,11 @@
-from collections import namedtuple
+import inspect
+from collections import namedtuple, deque
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
+import k_diffusion.sampling
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -127,3 +129,139 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
+
+
+class TorchHijack:
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.randn_like
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ return devices.randn_like(x)
+
+
+class Sampler:
+ def __init__(self, funcname):
+ self.funcname = funcname
+ self.func = funcname
+ self.extra_params = []
+ self.sampler_noises = None
+ self.stop_at = None
+ self.eta = None
+ self.config = None # set by the function calling the constructor
+ self.last_latent = None
+ self.s_min_uncond = None
+ self.s_churn = 0.0
+ self.s_tmin = 0.0
+ self.s_tmax = float('inf')
+ self.s_noise = 1.0
+
+ self.eta_option_field = 'eta_ancestral'
+ self.eta_infotext_field = 'Eta'
+
+ self.conditioning_key = shared.sd_model.model.conditioning_key
+
+ self.model_wrap = None
+ self.model_wrap_cfg = None
+
+ def callback_state(self, d):
+ step = d['i']
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
+ except InterruptedException:
+ return self.last_latent
+
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def initialize(self, p) -> dict:
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.step = 0
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+ self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
+
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+
+ extra_params_kwargs = {}
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
+
+ if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params[self.eta_infotext_field] = self.eta
+
+ extra_params_kwargs['eta'] = self.eta
+
+ if len(self.extra_params) > 0:
+ s_churn = getattr(opts, 's_churn', p.s_churn)
+ s_tmin = getattr(opts, 's_tmin', p.s_tmin)
+ s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
+ s_noise = getattr(opts, 's_noise', p.s_noise)
+
+ if s_churn != self.s_churn:
+ extra_params_kwargs['s_churn'] = s_churn
+ p.s_churn = s_churn
+ p.extra_generation_params['Sigma churn'] = s_churn
+ if s_tmin != self.s_tmin:
+ extra_params_kwargs['s_tmin'] = s_tmin
+ p.s_tmin = s_tmin
+ p.extra_generation_params['Sigma tmin'] = s_tmin
+ if s_tmax != self.s_tmax:
+ extra_params_kwargs['s_tmax'] = s_tmax
+ p.s_tmax = s_tmax
+ p.extra_generation_params['Sigma tmax'] = s_tmax
+ if s_noise != self.s_noise:
+ extra_params_kwargs['s_noise'] = s_noise
+ p.s_noise = s_noise
+ p.extra_generation_params['Sigma noise'] = s_noise
+
+ return extra_params_kwargs
+
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
+
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
+
+
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 9c9b46d1..3a2e01b7 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -4,8 +4,7 @@ import inspect
import k_diffusion.sampling
from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
-from modules.processing import StableDiffusionProcessing
-from modules.shared import opts, state
+from modules.shared import opts
import modules.shared as shared
samplers_k_diffusion = [
@@ -54,133 +53,17 @@ k_diffusion_scheduler = {
}
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
+class KDiffusionSampler(sd_samplers_common.Sampler):
+ def __init__(self, funcname, sd_model):
- return devices.randn_like(x)
+ super().__init__(funcname)
+ self.extra_params = sampler_extra_params.get(funcname, [])
+ self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.funcname = funcname
- self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
- self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap)
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.config = None # set by the function calling the constructor
- self.last_latent = None
- self.s_min_uncond = None
-
- # NOTE: These are also defined in the StableDiffusionProcessing class.
- # They should have been here to begin with but we're going to
- # leave that class __init__ signature alone.
- self.s_churn = 0.0
- self.s_tmin = 0.0
- self.s_tmax = float('inf')
- self.s_noise = 1.0
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if opts.live_preview_content == "Combined":
- sd_samplers_common.store_latent(latent)
- self.last_latent = latent
-
- if self.stop_at is not None and step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- state.sampling_step = step
- shared.total_tqdm.update()
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except RecursionError:
- print(
- 'Encountered RecursionError during sampling, returning last latent. '
- 'rho >5 with a polyexponential scheduler may cause this error. '
- 'You should try to use a smaller rho value instead.'
- )
- return self.last_latent
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p: StableDiffusionProcessing):
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.step = 0
- self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
- self.eta = p.eta if p.eta is not None else opts.eta_ancestral
- self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
-
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
-
- if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
- p.extra_generation_params["Eta"] = self.eta
-
- extra_params_kwargs['eta'] = self.eta
-
- if len(self.extra_params) > 0:
- s_churn = getattr(opts, 's_churn', p.s_churn)
- s_tmin = getattr(opts, 's_tmin', p.s_tmin)
- s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
- s_noise = getattr(opts, 's_noise', p.s_noise)
-
- if s_churn != self.s_churn:
- extra_params_kwargs['s_churn'] = s_churn
- p.s_churn = s_churn
- p.extra_generation_params['Sigma churn'] = s_churn
- if s_tmin != self.s_tmin:
- extra_params_kwargs['s_tmin'] = s_tmin
- p.s_tmin = s_tmin
- p.extra_generation_params['Sigma tmin'] = s_tmin
- if s_tmax != self.s_tmax:
- extra_params_kwargs['s_tmax'] = s_tmax
- p.s_tmax = s_tmax
- p.extra_generation_params['Sigma tmax'] = s_tmax
- if s_noise != self.s_noise:
- extra_params_kwargs['s_noise'] = s_noise
- p.s_noise = s_noise
- p.extra_generation_params['Sigma noise'] = s_noise
-
- return extra_params_kwargs
+ self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -232,22 +115,12 @@ class KDiffusionSampler:
return sigmas
- def create_noise_sampler(self, x, sigmas, p):
- """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
- if shared.opts.no_dpmpp_sde_batch_determinism:
- return None
-
- from k_diffusion.sampling import BrownianTreeNoiseSampler
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
- current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
- return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
-
sigma_sched = sigmas[steps - t_enc - 1:]
+
xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p)
@@ -296,12 +169,14 @@ class KDiffusionSampler:
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = steps
+
if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in parameters:
- extra_params_kwargs['n'] = steps
- else:
+
+ if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigmas
if self.config.options.get('brownian_noise', False):
@@ -322,3 +197,4 @@ class KDiffusionSampler:
return samples
+
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
new file mode 100644
index 00000000..8560d009
--- /dev/null
+++ b/modules/sd_samplers_timesteps.py
@@ -0,0 +1,147 @@
+import torch
+import inspect
+from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+
+from modules.shared import opts
+import modules.shared as shared
+
+samplers_timesteps = [
+ ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}),
+ ('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}),
+ ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}),
+]
+
+
+samplers_data_timesteps = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_timesteps
+]
+
+
+class CompVisTimestepsDenoiser(torch.nn.Module):
+ def __init__(self, model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.inner_model = model
+
+ def forward(self, input, timesteps, **kwargs):
+ return self.inner_model.apply_model(input, timesteps, **kwargs)
+
+
+class CompVisTimestepsVDenoiser(torch.nn.Module):
+ def __init__(self, model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.inner_model = model
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+
+ def forward(self, input, timesteps, **kwargs):
+ model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
+ e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
+ return e_t
+
+
+class CFGDenoiserTimesteps(CFGDenoiser):
+
+ def __init__(self, model, sampler):
+ super().__init__(model, sampler)
+
+ self.alphas = model.inner_model.alphas_cumprod
+
+ def get_pred_x0(self, x_in, x_out, sigma):
+ ts = int(sigma.item())
+
+ s_in = x_in.new_ones([x_in.shape[0]])
+ a_t = self.alphas[ts].item() * s_in
+ sqrt_one_minus_at = (1 - a_t).sqrt()
+
+ pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
+
+ return pred_x0
+
+
+class CompVisSampler(sd_samplers_common.Sampler):
+ def __init__(self, funcname, sd_model):
+ super().__init__(funcname)
+
+ self.eta_option_field = 'eta_ddim'
+ self.eta_infotext_field = 'Eta DDIM'
+
+ denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+ self.model_wrap = denoiser(sd_model)
+ self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
+
+ def get_timesteps(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
+
+ timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
+
+ return timesteps
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+ timesteps = self.get_timesteps(p, steps)
+ timesteps_sched = timesteps[:t_enc]
+
+ alphas_cumprod = shared.sd_model.alphas_cumprod
+ sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
+ sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
+
+ xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'timesteps' in parameters:
+ extra_params_kwargs['timesteps'] = timesteps_sched
+ if 'is_img2img' in parameters:
+ extra_params_kwargs['is_img2img'] = True
+
+ self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
+ extra_args = {
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps = steps or p.steps
+ timesteps = self.get_timesteps(p, steps)
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'timesteps' in parameters:
+ extra_params_kwargs['timesteps'] = timesteps
+
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ return samples
+
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
new file mode 100644
index 00000000..48d7e649
--- /dev/null
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -0,0 +1,135 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+import numpy as np
+
+from modules import shared
+from modules.models.diffusion.uni_pc import uni_pc
+
+
+@torch.no_grad()
+def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+
+ e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+
+ a_t = alphas[index].item() * s_in
+ a_prev = alphas_prev[index].item() * s_in
+ sigma_t = sigmas[index].item() * s_in
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+ noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+@torch.no_grad()
+def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ old_eps = []
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = alphas[index].item() * s_in
+ a_prev = alphas_prev[index].item() * s_in
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev).sqrt() * e_t
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev, pred_x0
+
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+ ts = timesteps[index].item() * s_in
+ t_next = timesteps[max(index - 1, 0)].item() * s_in
+
+ e_t = model(x, ts, **extra_args)
+
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = model(x_prev, t_next, **extra_args)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ else:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+
+ x = x_prev
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+class UniPCCFG(uni_pc.UniPC):
+ def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
+ super().__init__(None, *args, **kwargs)
+
+ def after_update(x, model_x):
+ callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
+ self.index += 1
+
+ self.cfg_model = cfg_model
+ self.extra_args = extra_args
+ self.callback = callback
+ self.index = 0
+ self.after_update = after_update
+
+ def get_model_input_time(self, t_continuous):
+ return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
+
+ def model(self, x, t):
+ t_input = self.get_model_input_time(t)
+
+ res = self.cfg_model(x, t_input, **self.extra_args)
+
+ return res
+
+
+def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+
+ ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
+ unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
+ x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x