aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py176
1 files changed, 90 insertions, 86 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 24c537d1..47712159 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -27,6 +27,7 @@ from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
+from blendmodes.blend import blendLayers, BlendType
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -39,17 +40,19 @@ def setup_color_correction(image):
return correction_target
-def apply_color_correction(correction, image):
+def apply_color_correction(correction, original_image):
logging.info("Applying color correction.")
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
- np.asarray(image),
+ np.asarray(original_image),
cv2.COLOR_RGB2LAB
),
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
-
+
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+
return image
@@ -73,11 +76,29 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -118,6 +139,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
if not seed_enable_extras:
@@ -132,28 +154,12 @@ class StableDiffusionProcessing():
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
+ self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
-
- self.is_using_inpainting_conditioning = True
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
@@ -199,7 +205,7 @@ class StableDiffusionProcessing():
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
-
+
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
@@ -235,7 +241,7 @@ class StableDiffusionProcessing():
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -243,6 +249,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
+ self.comments = comments
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name
@@ -314,7 +321,7 @@ class Processed:
return json.dumps(obj)
- def infotext(self, p: StableDiffusionProcessing, index):
+ def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
@@ -334,13 +341,14 @@ def slerp(val, low, high):
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+ eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -380,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
- if opts.eta_noise_seed_delta > 0:
- torch.manual_seed(seed + opts.eta_noise_seed_delta)
+ if eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -414,7 +422,7 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed)
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -429,6 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
@@ -446,7 +455,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
+ negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -463,12 +472,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
res = process_images_inner(p)
- finally: # restore opts to original state
- for k, v in stored_opts.items():
- setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ finally:
+ # restore opts to original state
+ if p.override_settings_restore_afterwards:
+ for k, v in stored_opts.items():
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+ if k == 'sd_vae': sd_vae.reload_vae_weights()
return res
@@ -535,9 +546,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ p.iteration = n
+
if state.skipped:
state.skipped = False
-
+
if state.interrupted:
break
@@ -612,7 +625,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
- del x_samples_ddim
+ del x_samples_ddim
devices.torch_gc()
@@ -638,7 +651,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+ res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
@@ -649,14 +662,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
- self.firstphase_width = firstphase_width
- self.firstphase_height = firstphase_height
- self.truncate_x = 0
- self.truncate_y = 0
+ self.hr_scale = hr_scale
+ self.hr_upscaler = hr_upscaler
+
+ if firstphase_width != 0 or firstphase_height != 0:
+ print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
+ self.hr_scale = self.width / firstphase_width
+ self.width = firstphase_width
+ self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -665,62 +682,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
-
- if self.firstphase_width == 0 or self.firstphase_height == 0:
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- firstphase_width_truncated = int(scale * self.width)
- firstphase_height_truncated = int(scale * self.height)
-
- else:
-
- width_ratio = self.width / self.firstphase_width
- height_ratio = self.height / self.firstphase_height
-
- if width_ratio > height_ratio:
- firstphase_width_truncated = self.firstphase_width
- firstphase_height_truncated = self.firstphase_width * self.height / self.width
- else:
- firstphase_width_truncated = self.firstphase_height * self.width / self.height
- firstphase_height_truncated = self.firstphase_height
-
- self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
- self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ if self.hr_upscaler is not None:
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and latent_scale_mode is None:
+ assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+
if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
+ target_width = int(self.width * self.hr_scale)
+ target_height = int(self.height * self.hr_scale)
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
-
- """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index):
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
+
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index)
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
- if opts.use_scale_latent_for_hires_fix:
+ if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
- # Avoid making the inpainting conditioning unless necessary as
+ # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
@@ -738,7 +738,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i)
- image = images.resize_image(0, image, self.width, self.height)
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
@@ -755,7 +755,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
@@ -829,9 +829,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
- image = img.convert("RGB")
+ image = images.flatten(img, opts.img2img_background_color)
- if crop_region is None:
+ if crop_region is None and self.resize_mode != 3:
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
@@ -840,6 +840,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.overlay_images.append(image_masked.convert('RGBA'))
+ # crop_region is not None if we are doing inpaint full res
if crop_region is not None:
image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height)
@@ -876,6 +877,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ if self.resize_mode == 3:
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
if image_mask is not None:
init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))