aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/processing.py94
-rw-r--r--modules/scripts.py70
-rw-r--r--modules/sd_samplers_cfg_denoiser.py23
3 files changed, 118 insertions, 69 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 7d46949f..5a1a90af 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,7 +30,6 @@ import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
-import modules.soft_inpainting as si
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
@@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc):
return image
-def apply_overlay(image, paste_loc, index, overlays):
- if overlays is None or index >= len(overlays):
+def apply_overlay(image, paste_loc, overlay):
+ if overlay is None:
return image
- overlay = overlays[index]
-
if paste_loc is not None:
image = uncrop(image, (overlay.width, overlay.height), paste_loc)
@@ -150,7 +147,6 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
- masks_for_overlay: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = None
@@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+ if p.scripts is not None:
+ ps = scripts.PostSampleArgs(samples_ddim)
+ p.scripts.post_sample(p, ps)
+ samples_ddim = pp.samples
+
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
- # todo: generate adaptive masks based on pixel differences.
- if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
- si.apply_masks(soft_inpainting=p.soft_inpainting,
- nmask=p.nmask,
- overlay_images=p.overlay_images,
- masks_for_overlay=p.masks_for_overlay,
- width=p.width,
- height=p.height,
- paste_to=p.paste_to)
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
- # Generate the mask(s) based on similarity between the original and denoised latent vectors
- if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
- si.apply_adaptive_masks(latent_orig=p.init_latent,
- latent_processed=samples_ddim,
- overlay_images=p.overlay_images,
- masks_for_overlay=p.masks_for_overlay,
- width=p.width,
- height=p.height,
- paste_to=p.paste_to)
-
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -955,9 +937,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
+
+ mask_for_overlay = p.mask_for_overlay
+ overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None
+
+ if p.scripts is not None:
+ ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+ p.scripts.postprocess_maskoverlay(p, ppmo)
+ mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction:
- image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
@@ -968,9 +959,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
original_denoised_image = image.copy()
if p.paste_to is not None:
- original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to)
+ original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to)
- image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ image = apply_overlay(image, p.paste_to, overlay_image)
if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -981,13 +972,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
- mask_for_overlay = p.mask_for_overlay
- elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]:
- mask_for_overlay = p.masks_for_overlay[i]
- else:
- mask_for_overlay = None
-
if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask:
image_mask = mask_for_overlay.convert('RGB')
@@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
- soft_inpainting: si.SoftInpaintingParameters = si.default
+ mask_round: bool = True
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
@@ -1447,7 +1431,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
- image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None))
+ image_mask = create_binary_mask(image_mask, round=self.mask_round)
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1465,7 +1449,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
- self.mask_for_overlay = image_mask if self.soft_inpainting is None else None
+ self.mask_for_overlay = image_mask
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1476,13 +1460,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+ np_mask = np.array(image_mask)
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+ self.mask_for_overlay = Image.fromarray(np_mask)
- if self.soft_inpainting is None:
- np_mask = np.array(image_mask)
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
-
- self.masks_for_overlay = [] if self.soft_inpainting is not None else None
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1504,15 +1485,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
- if self.soft_inpainting is not None:
- # We apply the masks AFTER to adjust mask based on changed content.
- self.overlay_images.append(image.convert('RGBA'))
- self.masks_for_overlay.append(image_mask)
- else:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
- self.overlay_images.append(image_masked.convert('RGBA'))
+ self.overlay_images.append(image_masked.convert('RGBA'))
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
@@ -1565,7 +1541,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- if self.soft_inpainting is None:
+ if self.mask_round:
latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
@@ -1578,7 +1554,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
@@ -1589,8 +1565,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
- if self.mask is not None and self.soft_inpainting is None:
- samples = samples * self.nmask + self.init_latent * self.mask
+ blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+ if self.scripts is not None:
+ mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True)
+ self.scripts.on_mask_blend(self, mba)
+ blended_samples = mba.blended_latent
+
+ samples = blended_samples
del x
devices.torch_gc()
diff --git a/modules/scripts.py b/modules/scripts.py
index 7f9454eb..92a07c56 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
+class MaskBlendArgs:
+ def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None):
+ self.current_latent = current_latent
+ self.nmask = nmask
+ self.init_latent = init_latent
+ self.mask = mask
+ self.blended_samples = blended_samples
+
+ self.denoiser = denoiser
+ self.is_final_blend = denoiser is None
+ self.sigma = sigma
+
+class PostSampleArgs:
+ def __init__(self, samples):
+ self.samples = samples
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
+class PostProcessMaskOverlayArgs:
+ def __init__(self, index, mask_for_overlay, overlay_image):
+ self.index = index
+ self.mask_for_overlay = mask_for_overlay
+ self.overlay_image = overlay_image
class PostprocessBatchListArgs:
def __init__(self, images):
@@ -206,6 +226,25 @@ class Script:
pass
+ def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+ """
+ Called in inpainting mode when the original content is blended with the inpainted content.
+ This is called at every step in the denoising process and once at the end.
+ If is_final_blend is true, this is called for the final blending stage.
+ Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+ """
+
+ pass
+
+ def post_sample(self, p, ps: PostSampleArgs, *args):
+ """
+ Called after the samples have been generated,
+ but before they have been decoded by the VAE, if applicable.
+ Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+ """
+
+ pass
+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -213,6 +252,13 @@ class Script:
pass
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+ """
+ Called for every image after it has been generated.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -767,6 +813,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+ def post_sample(self, p, ps: PostSampleArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.post_sample(p, ps, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+ def on_mask_blend(self, p, mba: MaskBlendArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.on_mask_blend(p, mba, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
@@ -775,6 +837,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_maskoverlay(p, ppmo, *script_args)
+ except Exception:
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index f13e8dcc..eb9d5daf 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -109,19 +109,16 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
# If we use masks, blending between the denoised and original latent images occurs here.
- def apply_blend(latent):
- if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function):
- return self.p.denoiser_masked_blend_function(
- self,
- # Using an argument dictionary so that arguments can be added without breaking extensions.
- args=
- {
- "denoiser": self,
- "current_latent": latent,
- "sigma": sigma
- })
- else:
- return self.init_latent * self.mask + self.nmask * latent
+ def apply_blend(current_latent):
+ blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+ if self.p.scripts is not None:
+ from modules import scripts
+ mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+ self.p.scripts.on_mask_blend(self.p, mba)
+ blended_latent = mba.blended_latent
+
+ return blended_latent
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None: