aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_cfg_denoiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_cfg_denoiser.py')
-rw-r--r--modules/sd_samplers_cfg_denoiser.py95
1 files changed, 82 insertions, 13 deletions
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..a73d3b03 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -53,9 +53,13 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
+ self.padded_cond_uncond_v0 = False
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -88,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc
+ def pad_cond_uncond(self, cond, uncond):
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ cond = pad_cond(cond, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ return cond, uncond
+
+ def pad_cond_uncond_v0(self, cond, uncond):
+ """
+ Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
+
+ If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
+ If 'uncond' is a tensor, it is padded directly.
+
+ If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
+ is repeated to match the number of columns in 'cond'.
+
+ If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
+ to match the number of columns in 'cond'.
+
+ Args:
+ cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
+ uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
+
+ Returns:
+ tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
+
+ Note:
+ This is the padding that was always used in DDIM before version 1.6.0
+ """
+
+ is_dict_cond = isinstance(uncond, dict)
+ uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
+
+ if uncond_vec.shape[1] < cond.shape[1]:
+ last_vector = uncond_vec[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
+ uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
+ self.padded_cond_uncond_v0 = True
+ elif uncond_vec.shape[1] > cond.shape[1]:
+ uncond_vec = uncond_vec[:, :cond.shape[1]]
+ self.padded_cond_uncond_v0 = True
+
+ if is_dict_cond:
+ uncond['crossattn'] = uncond_vec
+ else:
+ uncond = uncond_vec
+
+ return cond, uncond
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -105,8 +165,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # If we use masks, blending between the denoised and original latent images occurs here.
+ def apply_blend(current_latent):
+ blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+ if self.p.scripts is not None:
+ from modules import scripts
+ mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+ self.p.scripts.on_mask_blend(self.p, mba)
+ blended_latent = mba.blended_latent
+
+ return blended_latent
+
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -130,7 +203,7 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
@@ -146,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
- if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
- empty = shared.sd_model.cond_stage_model_empty_prompt
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
- if num_repeats < 0:
- tensor = pad_cond(tensor, -num_repeats, empty)
- self.padded_cond_uncond = True
- elif num_repeats > 0:
- uncond = pad_cond(uncond, num_repeats, empty)
- self.padded_cond_uncond = True
+ self.padded_cond_uncond_v0 = False
+ if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
+ tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
+ elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ tensor, uncond = self.pad_cond_uncond(tensor, uncond)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
@@ -207,8 +275,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)