aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py146
1 files changed, 108 insertions, 38 deletions
diff --git a/modules/processing.py b/modules/processing.py
index ea926fc3..c61bbfbd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -47,6 +47,25 @@ def apply_color_correction(correction, image):
return image
+def apply_overlay(image, paste_loc, index, overlays):
+ if overlays is None or index >= len(overlays):
+ return image
+
+ overlay = overlays[index]
+
+ if paste_loc is not None:
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
+
+ return image
+
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
@@ -104,6 +123,12 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+ self.scripts = None
+ self.script_args = None
+ self.all_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -304,7 +329,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -318,7 +343,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@@ -350,32 +375,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
- all_prompts = p.prompt
+ p.all_prompts = p.prompt
else:
- all_prompts = p.batch_size * p.n_iter * [p.prompt]
+ p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list:
- all_seeds = seed
+ p.all_seeds = seed
else:
- all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
- all_subseeds = subseed
+ p.all_subseeds = subseed
else:
- all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ if p.scripts is not None:
+ p.scripts.run_alwayson_scripts(p)
+
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
- p.init(all_prompts, all_seeds, all_subseeds)
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
@@ -387,15 +415,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -442,22 +468,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- if p.overlay_images is not None and i < len(p.overlay_images):
- overlay = p.overlay_images[i]
-
- if p.paste_to is not None:
- x, y, w, h = p.paste_to
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
-
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
@@ -490,10 +505,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -515,6 +530,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
@@ -536,21 +553,40 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ def create_dummy_mask(self, x, width=None, height=None):
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -587,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
@@ -595,7 +631,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -613,6 +649,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
+ self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
@@ -685,6 +722,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
@@ -714,10 +755,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+ else:
+ self.image_conditioning = torch.zeros(
+ self.init_latent.shape[0], 5, 1, 1,
+ dtype=self.init_latent.dtype,
+ device=self.init_latent.device
+ )
+
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask