aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r--modules/sd_samplers.py76
1 files changed, 49 insertions, 27 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4c123d3b..e904d860 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -9,7 +9,7 @@ import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images
+from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -23,16 +23,16 @@ samplers_k_diffusion = [
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
@@ -106,20 +106,32 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def single_sample_to_image(sample):
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
-def sample_to_image(samples, index=0):
- return single_sample_to_image(samples[index])
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
-def samples_to_image_grid(samples):
- return images.image_grid([single_sample_to_image(sample) for sample in samples])
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def store_latent(decoded):
@@ -288,6 +300,16 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -329,12 +351,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -444,16 +461,26 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
-
+ def get_sigmas(self, p, steps):
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
+ if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+ return sigmas
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = setup_img2img_steps(p, steps)
+
+ sigmas = self.get_sigmas(p, steps)
+
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
@@ -485,12 +512,7 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
+ sigmas = self.get_sigmas(p, steps)
x = x * sigmas[0]