author     random_thoughtss <random_thoughtss@proton.me>  2022-10-29 10:35:51 -0700
committer  random_thoughtss <random_thoughtss@proton.me>  2022-10-29 10:35:51 -0700
commit     6e2ce4e735db64afcd0fe637327ca4ec78335706 (patch)
tree       428309a1e52a5dfd2d7ce9f70652cb4cdfee9bab /modules/processing.py
parent     44ab954fabb9c1273366ebdca47f8da394d61aab (diff)
Added image conditioning to latent upscale.
Only computed if the mask weight is not 1.0, to avoid using extra memory. Also includes some code cleanup.
Diffstat (limited to 'modules/processing.py')
-rw-r--r--  modules/processing.py  29
1 file changed, 11 insertions(+), 18 deletions(-)
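
For readers skimming the diff below: the core change gates the more expensive image conditioning on the inpainting mask weight. A minimal sketch of the idea, with choose_image_conditioning, mask_weight, decode, img2img_cond, and txt2img_cond as hypothetical stand-ins for the real helpers:

    # Hypothetical stand-ins for decode_first_stage and the two
    # conditioning helpers; the names here are illustrative only.
    def choose_image_conditioning(samples, mask_weight, decode, img2img_cond, txt2img_cond):
        if mask_weight < 1.0:
            # Worth paying for the extra VAE decode: the conditioning
            # image actually blends in the original pixels.
            return img2img_cond(decode(samples), samples)
        # At mask_weight == 1.0 the blend is fully masked anyway, so the
        # cheap txt2img conditioning suffices.
        return txt2img_cond(samples)
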
diff --git a/modules/processing.py b/modules/processing.py
index f18b7db2..ee0e9e34 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
             # Dummy zero conditioning if we're not using inpainting model.
             # Still takes up a bit of memory, but no encoder call.
             # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
-            return torch.zeros(
-                x.shape[0], 5, 1, 1,
-                dtype=x.dtype,
-                device=x.device
-            )
+            return x.new_zeros(x.shape[0], 5, 1, 1)
 
         height = height or self.height
         width = width or self.width
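
The new_zeros rewrite above (and its twin in the next hunk) is behavior-preserving: Tensor.new_zeros inherits the source tensor's dtype and device, so the explicit dtype=/device= arguments become unnecessary. A quick check:

    import torch

    x = torch.randn(4, 3, 64, 64)

    # Both produce a (batch, 5, 1, 1) dummy conditioning tensor matching
    # x's dtype and device; new_zeros just spells it more compactly.
    a = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
    b = x.new_zeros(x.shape[0], 5, 1, 1)

    assert a.shape == b.shape and a.dtype == b.dtype and a.device == b.device
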
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
     def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
         if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
             # Dummy zero conditioning if we're not using inpainting model.
-            return torch.zeros(
-                latent_image.shape[0], 5, 1, 1,
-                dtype=latent_image.dtype,
-                device=latent_image.device
-            )
+            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
 
         # Handle the different mask inputs
         if image_mask is not None:
@@ -174,11 +166,10 @@ class StableDiffusionProcessing():
                 # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                 conditioning_mask = torch.round(conditioning_mask)
         else:
-            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
 
         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(source_image.device)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
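
Two things happen in the hunk above: source_image.new_ones allocates the fallback mask directly on source_image's device and dtype, which is what makes the dropped .to(source_image.device) call redundant, and torch.lerp blends the unmasked and masked conditioning images by the mask weight. A small illustration, using a literal mask_weight in place of the getattr(...) lookup that supplies it in the real code:

    import torch

    source_image = torch.randn(1, 3, 64, 64)
    conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
    mask_weight = 0.5  # stand-in for the configured inpainting mask weight

    # lerp(a, b, w) = a + w * (b - a); w=0 keeps the original image,
    # w=1 yields the fully masked image source_image * (1 - mask).
    conditioning_image = torch.lerp(
        source_image,
        source_image * (1.0 - conditioning_mask),
        mask_weight,
    )
    assert conditioning_image.shape == source_image.shape
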
@@ -653,7 +644,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if opts.use_scale_latent_for_hires_fix:
             samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
-            image_conditioning = self.txt2img_image_conditioning(samples)
+
+            # Avoid making the inpainting conditioning unless necessary as
+            # this does need some extra compute to decode / encode the image again.
+            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+            else:
+                image_conditioning = self.txt2img_image_conditioning(samples)
 
         else:
             decoded_samples = decode_first_stage(self.sd_model, samples)
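
The getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) lookup in the hunk above falls back to the global setting when the processing object carries no per-job override. Only the plain getattr behavior is shown below, with toy classes standing in for the real objects:

    class Opts:
        inpainting_mask_weight = 1.0  # toy stand-in for shared.opts

    class Processing:
        pass  # no per-job override set

    opts = Opts()
    p = Processing()

    # The instance attribute wins if present; otherwise use the global option.
    assert getattr(p, "inpainting_mask_weight", opts.inpainting_mask_weight) == 1.0

    p.inpainting_mask_weight = 0.5
    assert getattr(p, "inpainting_mask_weight", opts.inpainting_mask_weight) == 0.5
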
@@ -675,11 +672,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
-            image_conditioning = self.img2img_image_conditioning(
-                decoded_samples,
-                samples,
-                decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
-            )
+            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
 
         shared.state.nextjob()
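
This last hunk works because the image_mask=None default now reproduces the all-ones mask inside img2img_image_conditioning itself (via new_ones, second hunk above), so the caller no longer needs to build it by hand. A rough sketch of just that default, under the same assumption:

    import torch

    def default_conditioning_mask(source_image, image_mask=None):
        # Mirrors the None branch: condition on the whole image.
        if image_mask is None:
            image_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
        return image_mask

    decoded_samples = torch.randn(2, 3, 64, 64)
    mask = default_conditioning_mask(decoded_samples)
    assert mask.shape == (1, 1, 64, 64)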