aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py115
1 files changed, 85 insertions, 30 deletions
diff --git a/modules/processing.py b/modules/processing.py
index d22b353f..8da73884 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,5 @@
import json
+import logging
import math
import os
import sys
@@ -6,14 +7,14 @@ import hashlib
import torch
import numpy as np
-from PIL import Image, ImageFilter, ImageOps
+from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -23,7 +24,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -106,6 +106,9 @@ class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -171,12 +174,13 @@ class StableDiffusionProcessing:
self.prompts = None
self.negative_prompts = None
+ self.extra_network_data = None
self.seeds = None
self.subseeds = None
self.step_multiplier = 1
- self.cached_uc = [None, None]
- self.cached_c = [None, None]
+ self.cached_uc = StableDiffusionProcessing.cached_uc
+ self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None
@@ -288,8 +292,9 @@ class StableDiffusionProcessing:
self.sampler = None
self.c = None
self.uc = None
- self.cached_c = [None, None]
- self.cached_uc = [None, None]
+ if not opts.experimental_persistent_cond_cache:
+ StableDiffusionProcessing.cached_c = [None, None]
+ StableDiffusionProcessing.cached_uc = [None, None]
def get_token_merging_ratio(self, for_hr=False):
if for_hr:
@@ -311,7 +316,7 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
- def get_conds_with_caching(self, function, required_prompts, steps, cache):
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.
@@ -320,27 +325,29 @@ class StableDiffusionProcessing:
representing the previously used arguments, or None if no arguments
have been used before. The second element is where the previously
computed result is stored.
+
+ caches is a list with items described above.
"""
- if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
- return cache[1]
+ for cache in caches:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ return cache[1]
+
+ cache = caches[0]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1]
def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c)
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self):
- self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
-
- return extra_network_data
+ self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
class Processed:
@@ -588,6 +595,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
def process_images(p: StableDiffusionProcessing) -> Processed:
+ if p.scripts is not None:
+ p.scripts.before_process(p)
+
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
@@ -673,10 +683,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
- extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -697,11 +708,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(p.prompts) == 0:
break
- extra_network_data = p.parse_extra_network_prompts()
+ p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
- extra_networks.activate(p, extra_network_data)
+ extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -736,7 +747,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -823,8 +834,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks and extra_network_data:
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks and p.extra_network_data:
+ extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc()
@@ -859,6 +870,8 @@ def old_hires_fix_first_pass_dimensions(width, height):
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
+ cached_hr_uc = [None, None]
+ cached_hr_c = [None, None]
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
@@ -891,6 +904,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_negative_prompts = None
self.hr_extra_network_data = None
+ self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
+ self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
self.hr_c = None
self.hr_uc = None
@@ -970,7 +985,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
- assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
@@ -1053,6 +1069,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
+ with devices.autocast():
+ self.calculate_hr_conds()
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
@@ -1064,8 +1083,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
def close(self):
+ super().close()
self.hr_c = None
self.hr_uc = None
+ if not opts.experimental_persistent_cond_cache:
+ StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
+ StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
def setup_prompts(self):
super().setup_prompts()
@@ -1092,12 +1115,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+ def calculate_hr_conds(self):
+ if self.hr_c is not None:
+ return
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+
def setup_conds(self):
super().setup_conds()
+ self.hr_uc = None
+ self.hr_c = None
+
if self.enable_hr:
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c)
+ if shared.opts.hires_fix_use_firstpass_conds:
+ self.calculate_hr_conds()
+
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ self.calculate_hr_conds()
+
+ with devices.autocast():
+ extra_networks.activate(self, self.extra_network_data)
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
@@ -1114,7 +1156,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -1125,7 +1167,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_mask = mask
self.latent_mask = None
self.mask_for_overlay = None
- self.mask_blur = mask_blur
+ if mask_blur is not None:
+ mask_blur_x = mask_blur
+ mask_blur_y = mask_blur
+ self.mask_blur_x = mask_blur_x
+ self.mask_blur_y = mask_blur_y
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
@@ -1147,8 +1193,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
- if self.mask_blur > 0:
- image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+ if self.mask_blur_x > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+ image_mask = Image.fromarray(np_mask)
+
+ if self.mask_blur_y > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+ image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
self.mask_for_overlay = image_mask