aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py64
1 files changed, 42 insertions, 22 deletions
diff --git a/modules/processing.py b/modules/processing.py
index b1df4918..03c9143d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return torch.zeros(
- x.shape[0], 5, 1, 1,
- dtype=x.dtype,
- device=x.device
- )
+ return x.new_zeros(x.shape[0], 5, 1, 1)
height = height or self.height
width = width or self.width
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
- return torch.zeros(
- latent_image.shape[0], 5, 1, 1,
- dtype=latent_image.dtype,
- device=latent_image.device
- )
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs
if image_mask is not None:
@@ -174,11 +166,11 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
- conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(source_image.device)
+ conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
@@ -199,9 +191,13 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
+ def close(self):
+ self.sd_model = None
+ self.sampler = None
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
@@ -422,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+ setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
- opts.data[k] = v
+ setattr(opts, k, v)
return res
@@ -505,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
+ if p.scripts is not None:
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -517,7 +516,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -645,7 +644,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -658,9 +657,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
+ def save_intermediate(image, index):
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ return
+
+ if not isinstance(image, Image.Image):
+ image = sd_samplers.sample_to_image(image, index)
+
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+
if opts.use_scale_latent_for_hires_fix:
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
+
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ # Avoid making the inpainting conditioning unless necessary as
+ # this does need some extra compute to decode / encode the image again.
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -670,6 +688,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
+
+ save_intermediate(image, i)
+
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -681,14 +702,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- image_conditioning = self.txt2img_image_conditioning(x)
-
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
@@ -827,8 +848,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
@@ -839,4 +859,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples \ No newline at end of file
+ return samples