about | summary | refs | log | tree | commit | diff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
author catalpaaa <89681913+catalpaaa@users.noreply.github.com> 2023-05-01 11:59:21 -0700
committer GitHub <noreply@github.com> 2023-05-01 11:59:21 -0700
commit 9eb5b3e90f1775af81c828f19e7caded70ba8884 (patch)
tree 4630848311b90277462842669274852f7f6a972a /modules/sd_samplers_kdiffusion.py
parent ecdc6471e7d694ef9ecec96e8c3128237efe069a (diff)
parent 72cd27a13587c9579942577e9e3880778be195f6 (diff)
Merge branch 'experimental' into subpath-support
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r-- modules/sd_samplers_kdiffusion.py 46
1 files changed, 33 insertions, 13 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..eb98e599 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
return denoised
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -115,12 +115,21 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
+ skip_uncond = False
- if tensor.shape[1] == uncond.shape[1]:
- if not is_edit_model:
- cond_in = torch.cat([tensor, uncond])
- else:
+ # skipping uncond only on alternating (odd) steps allows a higher s_min_uncond threshold without the quality loss normally expected from raising that threshold
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
@@ -144,7 +153,13 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # uncond denoising was skipped, so the cond-denoised image is placed where the uncond-denoised image would normally be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params)
@@ -152,20 +167,21 @@ class CFGDenoiser(torch.nn.Module):
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
+ sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- if not is_edit_model:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- else:
+ if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.step += 1
-
return denoised
@@ -190,7 +206,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
- if x.device.type == 'mps':
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
@@ -210,6 +226,7 @@ class KDiffusionSampler:
self.eta = None
self.config = None
self.last_latent = None
+ self.s_min_uncond = None
self.conditioning_key = sd_model.model.conditioning_key
@@ -244,6 +261,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -326,6 +344,7 @@ class KDiffusionSampler:
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -359,7 +378,8 @@ class KDiffusionSampler:
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples