Diffstat (limited to 'modules/soft_inpainting.py')
-rw-r--r--  modules/soft_inpainting.py  66
1 file changed, 56 insertions(+), 10 deletions(-)
diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py
index 56a87774..b36ac8fa 100644
--- a/modules/soft_inpainting.py
+++ b/modules/soft_inpainting.py
@@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t):
# NOTE: We use inplace operations wherever possible.
- one_minus_t = 1 - t
+ # [4][w][h] to [1][4][w][h]
+ t2 = t.unsqueeze(0)
+ # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
+ t3 = t[0].unsqueeze(0).unsqueeze(0)
+
+ one_minus_t2 = 1 - t2
+ one_minus_t3 = 1 - t3
# Linearly interpolate the image vectors.
- a_scaled = a * one_minus_t
- b_scaled = b * t
+ a_scaled = a * one_minus_t2
+ b_scaled = b * t2
image_interp = a_scaled
image_interp.add_(b_scaled)
result_type = image_interp.dtype
- del a_scaled, b_scaled
+ del a_scaled, b_scaled, t2, one_minus_t2
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
# 64-bit operations are used here to allow large exponents.
- current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
+ current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
- a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
- b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
+ a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
+ b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
- del a_magnitude, b_magnitude, one_minus_t
+ del a_magnitude, b_magnitude, t3, one_minus_t3
# Change the linearly interpolated image vectors' magnitudes to the value we want.
# This is the last 64-bit operation.
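As context for the shape changes above, here is a minimal, self-contained sketch of the broadcasting this hunk relies on (shapes and tensor names are illustrative assumptions, not taken from the repository): the blend mask gains a leading batch dimension so it multiplies a batch of 4-channel latents, and keepdim=True leaves the per-pixel norm as [B][1][h][w] so it divides the [B][4][h][w] latents element-wise.

import torch

# Illustrative shapes only: batch of 2 latents, 4 channels, 8x8 spatial.
a = torch.randn(2, 4, 8, 8)      # original latents   [B][4][h][w]
b = torch.randn(2, 4, 8, 8)      # processed latents  [B][4][h][w]
t = torch.rand(4, 8, 8)          # per-pixel blend mask [4][h][w]

t2 = t.unsqueeze(0)              # [1][4][h][w] - broadcasts over the batch
blended = a * (1 - t2) + b * t2  # [B][4][h][w]

# keepdim=True preserves the channel axis, so the norm broadcasts back cleanly.
magnitude = torch.norm(blended, p=2, dim=1, keepdim=True)  # [B][1][h][w]
normalized = blended / (magnitude + 1e-5)                  # still [B][4][h][w]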
@@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
NOTE: "mask" is not used
"""
import torch
- return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
+ # todo: Why is sigma 2D? Both values are the same.
+ return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
-def generate_adaptive_masks(
+def apply_adaptive_masks(
latent_orig,
latent_processed,
overlay_images,
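The sigma indexing introduced above can be exercised in isolation with the following sketch (the nmask shape and the two settings values are assumptions for illustration): sigma arrives as a small per-batch tensor whose entries are identical, so only its first element is used as the scalar exponent.

import torch

nmask = torch.rand(1, 4, 8, 8)     # noise mask in latent space (illustrative shape)
sigma = torch.tensor([7.5, 7.5])   # per-batch sigmas; both entries carry the same value

mask_blend_power = 1.0             # assumed values mirroring the soft_inpainting settings
mask_blend_scale = 0.5

modified_nmask = torch.pow(nmask, (sigma[0] ** mask_blend_power) * mask_blend_scale)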
@@ -142,6 +149,45 @@ def generate_adaptive_masks(
overlay_images[i] = image_masked.convert('RGBA')
+def apply_masks(
+ soft_inpainting,
+ nmask,
+ overlay_images,
+ masks_for_overlay,
+ width, height,
+ paste_to):
+ import torch
+ import numpy as np
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ converted_mask = nmask[0].float()
+ converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc.uncrop(converted_mask,
+ (width, height),
+ paste_to)
+
+ for i, overlay_image in enumerate(overlay_images):
+ masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
# ------------------- Constants -------------------
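For reference, the mask conversion performed by the new apply_masks can be reproduced standalone roughly as follows (a sketch with assumed sizes and a placeholder overlay; PIL's resize stands in for images.resize_image, and the create_binary_mask/uncrop helpers are omitted):

import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps

nmask = torch.rand(1, 64, 64)   # illustrative latent-space mask
mask_blend_scale = 0.5          # assumed setting

converted = nmask[0].float().clamp(min=0, max=1).pow(mask_blend_scale / 2)
converted = (255.0 * converted).cpu().numpy().astype(np.uint8)
mask_img = Image.fromarray(converted).resize((512, 512))        # stand-in for images.resize_image
mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=4))  # soften aliasing artifacts

overlay = Image.new('RGBA', mask_img.size, (255, 0, 0, 255))    # placeholder overlay image
image_masked = Image.new('RGBa', overlay.size)
image_masked.paste(overlay.convert('RGBa'),
                   mask=ImageOps.invert(mask_img.convert('L')))
result = image_masked.convert('RGBA')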