aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py53
1 files changed, 32 insertions, 21 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 24c537d1..4a406084 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -27,6 +27,7 @@ from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
+from blendmodes.blend import blendLayers, BlendType
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -39,17 +40,19 @@ def setup_color_correction(image):
return correction_target
-def apply_color_correction(correction, image):
+def apply_color_correction(correction, original_image):
logging.info("Applying color correction.")
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
- np.asarray(image),
+ np.asarray(original_image),
cv2.COLOR_RGB2LAB
),
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
-
+
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+
return image
@@ -77,7 +80,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -118,6 +121,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
if not seed_enable_extras:
@@ -147,11 +151,11 @@ class StableDiffusionProcessing():
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
+ image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
@@ -199,7 +203,7 @@ class StableDiffusionProcessing():
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
-
+
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
@@ -314,7 +318,7 @@ class Processed:
return json.dumps(obj)
- def infotext(self, p: StableDiffusionProcessing, index):
+ def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
@@ -429,6 +433,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
@@ -446,7 +451,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
+ negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -463,12 +468,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
res = process_images_inner(p)
- finally: # restore opts to original state
- for k, v in stored_opts.items():
- setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ finally:
+ # restore opts to original state
+ if p.override_settings_restore_afterwards:
+ for k, v in stored_opts.items():
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+ if k == 'sd_vae': sd_vae.reload_vae_weights()
return res
@@ -537,7 +544,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
for n in range(p.n_iter):
if state.skipped:
state.skipped = False
-
+
if state.interrupted:
break
@@ -612,7 +619,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
- del x_samples_ddim
+ del x_samples_ddim
devices.torch_gc()
@@ -704,7 +711,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
- """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index):
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
@@ -720,7 +727,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
- # Avoid making the inpainting conditioning unless necessary as
+ # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
@@ -829,9 +836,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
- image = img.convert("RGB")
+ image = images.flatten(img, opts.img2img_background_color)
- if crop_region is None:
+ if crop_region is None and self.resize_mode != 3:
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
@@ -840,6 +847,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.overlay_images.append(image_masked.convert('RGBA'))
+ # crop_region is not None if we are doing inpaint full res
if crop_region is not None:
image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height)
@@ -876,6 +884,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ if self.resize_mode == 3:
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
if image_mask is not None:
init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))