aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py97
1 files changed, 66 insertions, 31 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 8086a2b0..ae58b108 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -539,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+class DecodedSamples(list):
+ already_decoded = True
+
+
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
- samples = []
+ samples = DecodedSamples()
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -788,7 +792,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ if getattr(samples_ddim, 'already_decoded', False):
+ x_samples_ddim = samples_ddim
+ else:
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -930,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -941,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_checkpoint_name = hr_checkpoint_name
+ self.hr_checkpoint_info = None
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
+ self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -968,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_checkpoint_name:
+ self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+ if self.hr_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+ self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -977,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
+ self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and self.latent_scale_mode is None:
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -1015,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
- # special case: the user has chosen to do nothing
- if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
- self.enable_hr = False
- self.denoising_strength = None
- self.extra_generation_params.pop("Hires upscale", None)
- self.extra_generation_params.pop("Hires resize", None)
- return
-
if not state.processing_has_refined_job_count:
if state.job_count == -1:
state.job_count = self.n_iter
@@ -1040,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
- if self.enable_hr and latent_scale_mode is None:
- if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
- raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ del x
if not self.enable_hr:
return samples
+ if self.latent_scale_mode is None:
+ decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+ else:
+ decoded_samples = None
+
+ current = shared.sd_model.sd_checkpoint_info
+ try:
+ if self.hr_checkpoint_info is not None:
+ self.sampler = None
+ sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+ devices.torch_gc()
+
+ return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+ finally:
+ self.sampler = None
+ sd_models.reload_model_weights(info=current)
+ devices.torch_gc()
+
+ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
@@ -1068,11 +1099,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
- if latent_scale_mode is not None:
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+ if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+ img2img_sampler_name = 'DDIM'
+
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+ if self.latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -1081,7 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = []
@@ -1100,6 +1137,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device)
decoded_samples = 2. * decoded_samples - 1.
+ decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
@@ -1107,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
- if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
- img2img_sampler_name = 'DDIM'
-
- self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
- x = None
devices.torch_gc()
if not self.disable_extra_networks:
@@ -1138,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+ decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
self.is_hr_pass = False
- return samples
+ return decoded_samples
def close(self):
super().close()
@@ -1179,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_c is not None:
return
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+ hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self):
super().setup_conds()
@@ -1188,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = None
self.hr_c = None
- if self.enable_hr:
+ if self.enable_hr and self.hr_checkpoint_info is None:
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()