aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorRoy Shilkrot <roy.shil@gmail.com>2022-10-31 11:45:52 -0400
committerRoy Shilkrot <roy.shil@gmail.com>2022-10-31 11:45:52 -0400
commit509fd1459bacfe0dcde1e5792ea3e046aa9f4d24 (patch)
treeeaecb0351d6a95e4c29eaebdcb8b304990fd3a91 /modules/processing.py
parentbdc90837987ed8919dd611fd01553b0c170ded5c (diff)
parentdc7425a56e7a014cbfa3b3d44ad2321e519fe378 (diff)
Merge remote-tracking branch 'upstream/master' into roy.add_simple_interrogate_api
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py137
1 files changed, 83 insertions, 54 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 4efba946..b1df4918 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -129,6 +129,73 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
+ def txt2img_image_conditioning(self, x, width=None, height=None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return torch.zeros(
+ x.shape[0], 5, 1, 1,
+ dtype=x.dtype,
+ device=x.device
+ )
+
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ return torch.zeros(
+ latent_image.shape[0], 5, 1, 1,
+ dtype=latent_image.dtype,
+ device=latent_image.device
+ )
+
+ # Handle the different mask inputs
+ if image_mask is not None:
+ if torch.is_tensor(image_mask):
+ conditioning_mask = image_mask
+ else:
+ conditioning_mask = np.array(image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(source_image.device)
+ conditioning_image = torch.lerp(
+ source_image,
+ source_image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
+
+ # Encode the new masked image using first stage of network.
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ return image_conditioning
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -329,6 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -411,7 +479,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -434,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -523,7 +591,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -571,37 +645,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def create_dummy_mask(self, x, width=None, height=None):
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -634,11 +687,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ image_conditioning = self.txt2img_image_conditioning(x)
+
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
@@ -770,33 +825,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
- else:
- self.image_conditioning = torch.zeros(
- self.init_latent.shape[0], 5, 1, 1,
- dtype=self.init_latent.dtype,
- device=self.init_latent.device
- )
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):