aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_cfg_denoiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_cfg_denoiser.py')
-rw-r--r--modules/sd_samplers_cfg_denoiser.py75
1 files changed, 73 insertions, 2 deletions
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..efbe7a40 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -43,6 +43,9 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.mask = None
self.nmask = None
+ self.mask_blend_power = 1
+ self.mask_blend_scale = 0.5
+ self.inpaint_detail_preservation = 4
self.init_latent = None
self.steps = None
"""number of steps as specified by user in UI"""
@@ -56,6 +59,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -89,6 +95,69 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ def latent_blend(a, b, t):
+ """
+ Interpolates two latent image representations according to the parameter t,
+ where the interpolated vectors' magnitudes are also interpolated separately.
+ The "detail_preservation" factor biases the magnitude interpolation towards
+ the larger of the two magnitudes.
+ """
+ # NOTE: We use inplace operations wherever possible.
+
+ one_minus_t = 1 - t
+
+ # Linearly interpolate the image vectors.
+ a_scaled = a * one_minus_t
+ b_scaled = b * t
+ image_interp = a_scaled
+ image_interp.add_(b_scaled)
+ result_type = image_interp.dtype
+ del a_scaled, b_scaled
+
+ # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+ # 64-bit operations are used here to allow large exponents.
+ current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
+
+ # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+ a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t
+ b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t
+ desired_magnitude = a_magnitude
+ desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation)
+ del a_magnitude, b_magnitude, one_minus_t
+
+ # Change the linearly interpolated image vectors' magnitudes to the value we want.
+ # This is the last 64-bit operation.
+ image_interp_scaling_factor = desired_magnitude
+ image_interp_scaling_factor.div_(current_magnitude)
+ image_interp_scaled = image_interp
+ image_interp_scaled.mul_(image_interp_scaling_factor)
+ del current_magnitude
+ del desired_magnitude
+ del image_interp
+ del image_interp_scaling_factor
+
+ image_interp_scaled = image_interp_scaled.to(result_type)
+ del result_type
+
+ return image_interp_scaled
+
+ def get_modified_nmask(nmask, _sigma):
+ """
+ Converts a negative mask representing the transparency of the original latent vectors being overlayed
+ to a mask that is scaled according to the denoising strength for this step.
+
+ Where:
+ 0 = fully opaque, infinite density, fully masked
+ 1 = fully transparent, zero density, fully unmasked
+
+ We bring this transparency to a power, as this allows one to simulate N number of blending operations
+ where N can be any positive real value. Using this one can control the balance of influence between
+ the denoiser and the original latents according to the sigma value.
+
+ NOTE: "mask" is not used
+ """
+ return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale)
+
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -105,8 +174,9 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +277,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)