aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author CodeHatchling <steve@codehatch.com> 2023-12-02 21:07:02 -0700
committer CodeHatchling <steve@codehatch.com> 2023-12-02 21:07:02 -0700
commit 73ab982d1b7394574d1cf2e0a151bc457eeed769 (patch)
tree 9faf742504d9b7cf31bd3548388c2bf972ddc8ad
parent 609dea36ea919aa7db42fd4233c416a45c74578b (diff)
Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from masks, while preserving the details of largely unchanged content.
-rw-r--r-- modules/processing.py 119
1 files changed, 90 insertions, 29 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 92fdebad..ad716e11 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,7 +9,7 @@ from dataclasses import dataclass, field
import torch
import numpy as np
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageFilter
import random
import cv2
from skimage import exposure
@@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
+def uncrop(image, dest_size, paste_loc):  # place `image` onto a new RGBA canvas of `dest_size` at the region given by `paste_loc`
+ x, y, w, h = paste_loc  # paste_loc unpacks to a paste position (x, y) and a target size (w, h)
+ base_image = Image.new('RGBA', dest_size)  # new canvas; no fill colour is passed, so Pillow's default applies
+ image = images.resize_image(1, image, w, h)  # resize to w x h using resize mode 1 (mode meaning is defined by images.resize_image)
+ base_image.paste(image, (x, y))  # no mask argument: the resized image overwrites the canvas pixels in that box
+ image = base_image
+
+ return image  # returns the full-size RGBA canvas, not the resized input
+
+
def apply_overlay(image, paste_loc, index, overlays):
if overlays is None or index >= len(overlays):
return image
@@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays):
overlay = overlays[index]
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -140,6 +146,7 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
+ masks_for_overlay: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = 0
@@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
+ # todo: generate masks the old fashioned way
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ # Generate the mask(s) based on similarity between the original and denoised latent vectors
+ if getattr(p, "image_mask", None) is not None:
+ # latent_mask = p.nmask[0].float().cpu()
+
+ # convert the original mask into a form we use to scale distances for thresholding
+ # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
+ # mask_scalar = mask_scalar / (1.00001-mask_scalar)
+ # mask_scalar = mask_scalar.numpy()
+
+ latent_orig = p.init_latent
+ latent_proc = samples_ddim
+ latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
+
+ kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+ for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
+ converted_mask = distance_map.float().cpu().numpy()
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.9, percentile_max=1, min_width=1)
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+ # The distance at which opacity of original decreases to 50%
+ # half_weighted_distance = 1 # * mask_scalar
+ # converted_mask = converted_mask / half_weighted_distance
+
+ converted_mask = 1 / (1 + converted_mask ** 2)
+ converted_mask = images.smootherstep(converted_mask)
+ converted_mask = 1 - converted_mask
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
+ converted_mask = create_binary_mask(converted_mask)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if p.paste_to is not None:
+ converted_mask = uncrop(converted_mask,
+ (overlay_image.width, overlay_image.height),
+ p.paste_to)
+
+ p.masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ p.overlay_images[i] = image_masked.convert('RGBA')
+
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
+ target_device=devices.cpu,
+ check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = batch_params.images
def infotext(index=0, use_main_prompt=False):
- return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+ return create_infotext(p, p.prompts, p.seeds, p.subseeds,
+ use_main_prompt=use_main_prompt, index=index,
+ all_negative_prompts=p.negative_prompts)
save_samples = p.save_samples()
@@ -923,19 +987,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if save_samples:
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i],
+ p.prompts[i], opts.samples_format, info=infotext(i), p=p)
text = infotext(i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
- if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
- image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+ image_mask = p.masks_for_overlay[i].convert('RGB')
+ image_mask_composite = Image.composite(
+ original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+ images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
@@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
nmask: torch.Tensor = field(default=None, init=False)
image_conditioning: torch.Tensor = field(default=None, init=False)
init_img_hash: str = field(default=None, init=False)
- mask_for_overlay: Image = field(default=None, init=False)
init_latent: torch.Tensor = field(default=None, init=False)
def __post_init__(self):
@@ -1415,12 +1486,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
- np_mask = np.array(image_mask).astype(np.float32)
- np_mask /= 255
- np_mask = 1-pow(1-np_mask, 100)
- np_mask *= 255
- np_mask = np.clip(np_mask, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1431,13 +1496,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
- np_mask = np.array(image_mask).astype(np.float32)
- np_mask /= 255
- np_mask = 1-pow(1-np_mask, 100)
- np_mask *= 255
- np_mask = np.clip(np_mask, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
+ self.masks_for_overlay = []
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1459,10 +1519,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
-
- self.overlay_images.append(image_masked.convert('RGBA'))
+ self.overlay_images.append(image)
+ self.masks_for_overlay.append(image_mask)
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
@@ -1486,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+ if self.masks_for_overlay is not None:
+ self.masks_for_overlay = self.masks_for_overlay * self.batch_size
+
if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size