aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/images.py191
-rw-r--r--modules/img2img.py48
-rw-r--r--modules/processing.py123
-rw-r--r--modules/sd_samplers_cfg_denoiser.py75
-rw-r--r--modules/sd_samplers_common.py3
-rw-r--r--modules/ui.py9
-rw-r--r--scripts/outpainting_mk_2.py10
-rw-r--r--scripts/poor_mans_outpainting.py11
-rw-r--r--test/test_img2img.py3
9 files changed, 437 insertions, 36 deletions
diff --git a/modules/images.py b/modules/images.py
index daf4eebe..6648097e 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -791,3 +791,194 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+ """
+ Generalization convolution filter capable of applying
+ weighted mean, median, maximum, and minimum filters
+ parametrically using an arbitrary kernel.
+
+ Args:
+ img (nparray):
+ The image, a 2-D array of floats, to which the filter is being applied.
+ kernel (nparray):
+ The kernel, a 2-D array of floats.
+ kernel_center (nparray):
+ The kernel center coordinate, a 1-D array with two elements.
+ percentile_min (float):
+ The lower bound of the histogram window used by the filter,
+ from 0 to 1.
+ percentile_max (float):
+ The upper bound of the histogram window used by the filter,
+ from 0 to 1.
+ min_width (float):
+ The minimum size of the histogram window bounds, in weight units.
+ Must be greater than 0.
+
+ Returns:
+ (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+ """
+
+ # Converts an index tuple into a vector.
+ def vec(x):
+ return np.array(x)
+
+ kernel_min = -kernel_center
+ kernel_max = vec(kernel.shape) - kernel_center
+
+ def weighted_histogram_filter_single(idx):
+ idx = vec(idx)
+ min_index = np.maximum(0, idx + kernel_min)
+ max_index = np.minimum(vec(img.shape), idx + kernel_max)
+ window_shape = max_index - min_index
+
+ class WeightedElement:
+ """
+ An element of the histogram, its weight
+ and bounds.
+ """
+ def __init__(self, value, weight):
+ self.value: float = value
+ self.weight: float = weight
+ self.window_min: float = 0.0
+ self.window_max: float = 1.0
+
+ # Collect the values in the image as WeightedElements,
+ # weighted by their corresponding kernel values.
+ values = []
+ for window_tup in np.ndindex(tuple(window_shape)):
+ window_index = vec(window_tup)
+ image_index = window_index + min_index
+ centered_kernel_index = image_index - idx
+ kernel_index = centered_kernel_index + kernel_center
+ element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
+ values.append(element)
+
+ def sort_key(x: WeightedElement):
+ return x.value
+
+ values.sort(key=sort_key)
+
+ # Calculate the height of the stack (sum)
+ # and each sample's range they occupy in the stack
+ sum = 0
+ for i in range(len(values)):
+ values[i].window_min = sum
+ sum += values[i].weight
+ values[i].window_max = sum
+
+ # Calculate what range of this stack ("window")
+ # we want to get the weighted average across.
+ window_min = sum * percentile_min
+ window_max = sum * percentile_max
+ window_width = window_max - window_min
+
+ # Ensure the window is within the stack and at least a certain size.
+ if window_width < min_width:
+ window_center = (window_min + window_max) / 2
+ window_min = window_center - min_width / 2
+ window_max = window_center + min_width / 2
+
+ if window_max > sum:
+ window_max = sum
+ window_min = sum - min_width
+
+ if window_min < 0:
+ window_min = 0
+ window_max = min_width
+
+ value = 0
+ value_weight = 0
+
+ # Get the weighted average of all the samples
+ # that overlap with the window, weighted
+ # by the size of their overlap.
+ for i in range(len(values)):
+ if window_min >= values[i].window_max:
+ continue
+ if window_max <= values[i].window_min:
+ break
+
+ s = max(window_min, values[i].window_min)
+ e = min(window_max, values[i].window_max)
+ w = e - s
+
+ value += values[i].value * w
+ value_weight += w
+
+ return value / value_weight if value_weight != 0 else 0
+
+ img_out = img.copy()
+
+ # Apply the kernel operation over each pixel.
+ for index in np.ndindex(img.shape):
+ img_out[index] = weighted_histogram_filter_single(index)
+
+ return img_out
+
+def smoothstep(x):
+ """
+ The smoothstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * (3 - 2 * x)
+
+def smootherstep(x):
+ """
+ The smootherstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * x * (x * (6 * x - 15) + 10)
+
+
+def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
+ """
+ Creates a Gaussian kernel with thresholded edges.
+
+ Args:
+ stddev_radius (float):
+ Standard deviation of the gaussian kernel, in pixels.
+ max_radius (int):
+ The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
+ The kernel is thresholded so that any values one pixel beyond this radius
+ is weighted at 0.
+
+ Returns:
+ (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
+ """
+ # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
+ def gaussian(sqr_mag):
+ return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
+
+ # Helper function for converting a tuple to an array.
+ def vec(x):
+ return np.array(x)
+
+ """
+ Since a gaussian is unbounded, we need to limit ourselves
+ to a finite range.
+ We taper the ends off at the end of that range so they equal zero
+ while preserving the maximum value of 1 at the mean.
+ """
+ zero_radius = max_radius + 1.0
+ gauss_zero = gaussian(zero_radius * zero_radius)
+ gauss_kernel_scale = 1 / (1 - gauss_zero)
+
+ def gaussian_kernel_func(coordinate):
+ x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
+ x = gaussian(x)
+ x -= gauss_zero
+ x /= gauss_kernel_scale
+ x = max(0.0, x)
+ return x
+
+ size = max_radius * 2 + 1
+ kernel_center = max_radius
+ kernel = np.zeros((size, size))
+
+ for index in np.ndindex(kernel.shape):
+ kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)
+
+ return kernel, kernel_center
+
diff --git a/modules/img2img.py b/modules/img2img.py
index c583290a..596f741c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -146,7 +146,47 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img(id_task: str,
+ mode: int,
+ prompt: str,
+ negative_prompt: str,
+ prompt_styles,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
+ init_img_inpaint,
+ init_mask_inpaint,
+ steps: int,
+ sampler_name: str,
+ mask_blur: int,
+ mask_alpha: float,
+ mask_blend_power: float,
+ mask_blend_scale: float,
+ inpaint_detail_preservation: float,
+ inpainting_fill: int,
+ n_iter: int,
+ batch_size: int,
+ cfg_scale: float,
+ image_cfg_scale: float,
+ denoising_strength: float,
+ selected_scale_tab: int,
+ height: int,
+ width: int,
+ scale_by: float,
+ resize_mode: int,
+ inpaint_full_res: bool,
+ inpaint_full_res_padding: int,
+ inpainting_mask_invert: int,
+ img2img_batch_input_dir: str,
+ img2img_batch_output_dir: str,
+ img2img_batch_inpaint_mask_dir: str,
+ override_settings_texts,
+ img2img_batch_use_png_info: bool,
+ img2img_batch_png_info_props: list,
+ img2img_batch_png_info_dir: str,
+ request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -204,6 +244,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
init_images=[image],
mask=mask,
mask_blur=mask_blur,
+ mask_blend_power=mask_blend_power,
+ mask_blend_scale=mask_blend_scale,
+ inpaint_detail_preservation=inpaint_detail_preservation,
inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
@@ -224,6 +267,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
+ p.extra_generation_params["Mask blending bias"] = mask_blend_power
+ p.extra_generation_params["Mask blending preservation"] = mask_blend_scale
+ p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation
with closing(p):
if is_batch:
diff --git a/modules/processing.py b/modules/processing.py
index ac58ef86..66aaab83 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,7 +9,7 @@ from dataclasses import dataclass, field
import torch
import numpy as np
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageFilter
import random
import cv2
from skimage import exposure
@@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
+def uncrop(image, dest_size, paste_loc):
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', dest_size)
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ return image
+
+
def apply_overlay(image, paste_loc, index, overlays):
if overlays is None or index >= len(overlays):
return image
@@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays):
overlay = overlays[index]
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -83,7 +89,7 @@ def apply_overlay(image, paste_loc, index, overlays):
def create_binary_mask(image):
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
@@ -140,6 +146,7 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
+ masks_for_overlay: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = None
@@ -319,9 +326,6 @@ class StableDiffusionProcessing:
conditioning_mask = np.array(image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -869,10 +873,65 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
+ # todo: generate masks the old fashioned way
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ # Generate the mask(s) based on similarity between the original and denoised latent vectors
+ if getattr(p, "image_mask", None) is not None:
+ # latent_mask = p.nmask[0].float().cpu()
+
+ # convert the original mask into a form we use to scale distances for thresholding
+ # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
+ # mask_scalar = mask_scalar / (1.00001-mask_scalar)
+ # mask_scalar = mask_scalar.numpy()
+
+ latent_orig = p.init_latent
+ latent_proc = samples_ddim
+ latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
+
+ kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+ for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
+ converted_mask = distance_map.float().cpu().numpy()
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.9, percentile_max=1, min_width=1)
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+ # The distance at which opacity of original decreases to 50%
+ # half_weighted_distance = 1 # * mask_scalar
+ # converted_mask = converted_mask / half_weighted_distance
+
+ converted_mask = 1 / (1 + converted_mask ** 2)
+ converted_mask = images.smootherstep(converted_mask)
+ converted_mask = 1 - converted_mask
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
+ converted_mask = create_binary_mask(converted_mask)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if p.paste_to is not None:
+ converted_mask = uncrop(converted_mask,
+ (overlay_image.width, overlay_image.height),
+ p.paste_to)
+
+ p.masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ p.overlay_images[i] = image_masked.convert('RGBA')
+
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
+ target_device=devices.cpu,
+ check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -897,7 +956,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = batch_params.images
def infotext(index=0, use_main_prompt=False):
- return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+ return create_infotext(p, p.prompts, p.seeds, p.subseeds,
+ use_main_prompt=use_main_prompt, index=index,
+ all_negative_prompts=p.negative_prompts)
save_samples = p.save_samples()
@@ -928,19 +989,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if save_samples:
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i],
+ p.prompts[i], opts.samples_format, info=infotext(i), p=p)
text = infotext(i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
- if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
- image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+ image_mask = p.masks_for_overlay[i].convert('RGB')
+ image_mask_composite = Image.composite(
+ original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+ images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
@@ -1352,6 +1421,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
+ mask_blend_power: float = 1
+ mask_blend_scale: float = 0.5
+ inpaint_detail_preservation: float = 4
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
@@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
nmask: torch.Tensor = field(default=None, init=False)
image_conditioning: torch.Tensor = field(default=None, init=False)
init_img_hash: str = field(default=None, init=False)
- mask_for_overlay: Image = field(default=None, init=False)
init_latent: torch.Tensor = field(default=None, init=False)
def __post_init__(self):
@@ -1415,7 +1486,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
- self.mask_for_overlay = image_mask
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1426,10 +1496,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
- np_mask = np.array(image_mask)
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
+ self.masks_for_overlay = []
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1451,10 +1519,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
-
- self.overlay_images.append(image_masked.convert('RGBA'))
+ self.overlay_images.append(image)
+ self.masks_for_overlay.append(image_mask)
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
@@ -1478,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+ if self.masks_for_overlay is not None:
+ self.masks_for_overlay = self.masks_for_overlay * self.batch_size
+
if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size
@@ -1504,7 +1573,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1527,9 +1595,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
- if self.mask is not None:
- samples = samples * self.nmask + self.init_latent * self.mask
-
del x
devices.torch_gc()
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..efbe7a40 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -43,6 +43,9 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.mask = None
self.nmask = None
+ self.mask_blend_power = 1
+ self.mask_blend_scale = 0.5
+ self.inpaint_detail_preservation = 4
self.init_latent = None
self.steps = None
"""number of steps as specified by user in UI"""
@@ -56,6 +59,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -89,6 +95,69 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ def latent_blend(a, b, t):
+ """
+ Interpolates two latent image representations according to the parameter t,
+ where the interpolated vectors' magnitudes are also interpolated separately.
+ The "detail_preservation" factor biases the magnitude interpolation towards
+ the larger of the two magnitudes.
+ """
+ # NOTE: We use inplace operations wherever possible.
+
+ one_minus_t = 1 - t
+
+ # Linearly interpolate the image vectors.
+ a_scaled = a * one_minus_t
+ b_scaled = b * t
+ image_interp = a_scaled
+ image_interp.add_(b_scaled)
+ result_type = image_interp.dtype
+ del a_scaled, b_scaled
+
+ # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+ # 64-bit operations are used here to allow large exponents.
+ current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
+
+ # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+ a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t
+ b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t
+ desired_magnitude = a_magnitude
+ desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation)
+ del a_magnitude, b_magnitude, one_minus_t
+
+ # Change the linearly interpolated image vectors' magnitudes to the value we want.
+ # This is the last 64-bit operation.
+ image_interp_scaling_factor = desired_magnitude
+ image_interp_scaling_factor.div_(current_magnitude)
+ image_interp_scaled = image_interp
+ image_interp_scaled.mul_(image_interp_scaling_factor)
+ del current_magnitude
+ del desired_magnitude
+ del image_interp
+ del image_interp_scaling_factor
+
+ image_interp_scaled = image_interp_scaled.to(result_type)
+ del result_type
+
+ return image_interp_scaled
+
+ def get_modified_nmask(nmask, _sigma):
+ """
+ Converts a negative mask representing the transparency of the original latent vectors being overlayed
+ to a mask that is scaled according to the denoising strength for this step.
+
+ Where:
+ 0 = fully opaque, infinite density, fully masked
+ 1 = fully transparent, zero density, fully unmasked
+
+ We bring this transparency to a power, as this allows one to simulate N number of blending operations
+ where N can be any positive real value. Using this one can control the balance of influence between
+ the denoiser and the original latents according to the sigma value.
+
+ NOTE: "mask" is not used
+ """
+ return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale)
+
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -105,8 +174,9 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +277,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 58efcad2..ecd8ab0a 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -277,6 +277,9 @@ class Sampler:
self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None
+ self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None
+ self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
diff --git a/modules/ui.py b/modules/ui.py
index 08e0ad77..b13ed66c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -678,6 +678,9 @@ def create_ui():
with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+ mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power")
+ mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale")
+ inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset")
with FormRow():
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
@@ -733,6 +736,9 @@ def create_ui():
sampler_name,
mask_blur,
mask_alpha,
+ mask_blend_power,
+ mask_blend_scale,
+ inpaint_detail_preservation,
inpainting_fill,
batch_count,
batch_size,
@@ -831,6 +837,9 @@ def create_ui():
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
+ (mask_blend_power, "Mask blending bias"),
+ (mask_blend_scale, "Mask blending preservation"),
+ (inpaint_detail_preservation, "Mask blending contrast boost"),
*scripts.scripts_img2img.infotext_fields
]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index c98ab480..bd9cb61b 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -133,13 +133,16 @@ class Script(scripts.Script):
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
+ mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power"))
+ mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale"))
+ inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation"))
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
- return [info, pixels, mask_blur, direction, noise_q, color_variation]
+ return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation]
- def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
+ def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation):
initial_seed_and_info = [None, None]
process_width = p.width
@@ -167,6 +170,9 @@ class Script(scripts.Script):
p.mask_blur_x = mask_blur_x*4
p.mask_blur_y = mask_blur_y*4
+ p.mask_blend_power = mask_blend_power
+ p.mask_blend_scale = mask_blend_scale
+ p.inpaint_detail_preservation = inpaint_detail_preservation
init_img = p.init_images[0]
target_w = math.ceil((init_img.width + left + right) / 64) * 64
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index ea0632b6..5388f5db 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -22,16 +22,23 @@ class Script(scripts.Script):
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
+ mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power"))
+ mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale"))
+ inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation"))
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
- return [pixels, mask_blur, inpainting_fill, direction]
+ return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction]
- def run(self, p, pixels, mask_blur, inpainting_fill, direction):
+ def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction):
initial_seed = None
initial_info = None
p.mask_blur = mask_blur * 2
+ p.mask_blend_power = mask_blend_power
+ p.mask_blend_scale = mask_blend_scale
+ p.inpaint_detail_preservation = inpaint_detail_preservation
+
p.inpainting_fill = inpainting_fill
p.inpaint_full_res = False
diff --git a/test/test_img2img.py b/test/test_img2img.py
index 117d2d1e..5cda2dba 100644
--- a/test/test_img2img.py
+++ b/test/test_img2img.py
@@ -24,6 +24,9 @@ def simple_img2img_request(img2img_basic_image_base64):
"inpainting_mask_invert": False,
"mask": None,
"mask_blur": 4,
+ "mask_blend_power": 1,
+ "mask_blend_scale": 0.5,
+ "inpaint_detail_preservation": 4,
"n_iter": 1,
"negative_prompt": "",
"override_settings": {},