From 56604f08a18588e8e6b57d7c3f9c61d6624846f8 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:53:44 -0700 Subject: Moved image filters used by soft inpainting into soft_inpainting.py from images.py --- scripts/soft_inpainting.py | 205 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 199 insertions(+), 6 deletions(-) (limited to 'scripts/soft_inpainting.py') diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6d0cf847..1f451b55 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -1,4 +1,6 @@ +import numpy as np import gradio as gr +import math from modules.ui_components import InputAccordion import modules.scripts as scripts @@ -101,7 +103,6 @@ def apply_adaptive_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -115,15 +116,15 @@ def apply_adaptive_masks( latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) masks_for_overlay = [] for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% @@ -131,7 +132,7 @@ def apply_adaptive_masks( # converted_mask = converted_mask / half_weighted_distance converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) + converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. * converted_mask converted_mask = converted_mask.astype(np.uint8) @@ -166,7 +167,6 @@ def apply_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -202,6 +202,196 @@ def apply_masks( return masks_for_overlay +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + # ------------------- Constants ------------------- @@ -232,6 +422,9 @@ el_ids = SoftInpaintingSettings( "inpaint_detail_preservation") +# ----- + + class Script(scripts.Script): def __init__(self): -- cgit v1.2.1