aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_cfg_denoiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_cfg_denoiser.py')
-rw-r--r--modules/sd_samplers_cfg_denoiser.py41
1 files changed, 28 insertions, 13 deletions
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index ceb612d7..efbe7a40 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -102,29 +102,44 @@ class CFGDenoiser(torch.nn.Module):
The "detail_preservation" factor biases the magnitude interpolation towards
the larger of the two magnitudes.
"""
- # Record the original latent vector magnitudes.
- # We bring them to a power so that larger magnitudes are favored over smaller ones.
- # 64-bit operations are used here to allow large exponents.
- a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation
- b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation
+ # NOTE: We use inplace operations wherever possible.
one_minus_t = 1 - t
- # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
- interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation)
-
# Linearly interpolate the image vectors.
- image_interp = a * one_minus_t + b * t
+ a_scaled = a * one_minus_t
+ b_scaled = b * t
+ image_interp = a_scaled
+ image_interp.add_(b_scaled)
+ result_type = image_interp.dtype
+ del a_scaled, b_scaled
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
- image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001
+ current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
+
+ # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+ a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t
+ b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t
+ desired_magnitude = a_magnitude
+ desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation)
+ del a_magnitude, b_magnitude, one_minus_t
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
- image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype)
-
- return image_interp
+ image_interp_scaling_factor = desired_magnitude
+ image_interp_scaling_factor.div_(current_magnitude)
+ image_interp_scaled = image_interp
+ image_interp_scaled.mul_(image_interp_scaling_factor)
+ del current_magnitude
+ del desired_magnitude
+ del image_interp
+ del image_interp_scaling_factor
+
+ image_interp_scaled = image_interp_scaled.to(result_type)
+ del result_type
+
+ return image_interp_scaled
def get_modified_nmask(nmask, _sigma):
"""