aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py118
1 files changed, 88 insertions, 30 deletions
diff --git a/modules/processing.py b/modules/processing.py
index a172af0b..7e853287 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -136,28 +154,12 @@ class StableDiffusionProcessing():
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
+ self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
@@ -420,7 +422,7 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed)
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -544,6 +546,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ p.iteration = n
+
if state.skipped:
state.skipped = False
@@ -658,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
self.hr_scale = hr_scale
self.hr_upscaler = hr_upscaler
+ self.hr_second_pass_steps = hr_second_pass_steps
+ self.hr_resize_x = hr_resize_x
+ self.hr_resize_y = hr_resize_y
+ self.hr_upscale_to_x = hr_resize_x
+ self.hr_upscale_to_y = hr_resize_y
if firstphase_width != 0 or firstphase_height != 0:
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
@@ -671,14 +680,60 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.width = firstphase_width
self.height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
+
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if state.job_count == -1:
- state.job_count = self.n_iter * 2
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
else:
+ self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
+
+ if self.hr_resize_y == 0:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ elif self.hr_resize_x == 0:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+ else:
+ target_w = self.hr_resize_x
+ target_h = self.hr_resize_y
+ src_ratio = self.width / self.height
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
+
+ if src_ratio < dst_ratio:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ else:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+
+ # special case: the user has chosen to do nothing
+ if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
+ self.enable_hr = False
+ self.denoising_strength = None
+ self.extra_generation_params.pop("Hires upscale", None)
+ self.extra_generation_params.pop("Hires resize", None)
+ return
+
+ if not state.processing_has_refined_job_count:
+ if state.job_count == -1:
+ state.job_count = self.n_iter
+
+ shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
state.job_count = state.job_count * 2
+ state.processing_has_refined_job_count = True
+
+ if self.hr_second_pass_steps:
+ self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
- self.extra_generation_params["Hires upscale"] = self.hr_scale
if self.hr_upscaler is not None:
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
@@ -695,8 +750,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr:
return samples
- target_width = int(self.width * self.hr_scale)
- target_height = int(self.height * self.hr_scale)
+ target_width = self.hr_upscale_to_x
+ target_height = self.hr_upscale_to_y
def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
@@ -705,15 +760,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return
if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index)
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -750,13 +806,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
return samples