aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py45
1 files changed, 38 insertions, 7 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 6567b3cf..a74a5302 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,7 +14,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -538,6 +538,40 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+ samples = []
+
+ for i in range(batch.shape[0]):
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if check_for_nans:
+ try:
+ devices.test_for_nans(sample, "vae")
+ except devices.NansException as e:
+ if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ raise e
+
+ errors.print_error_explanation(
+ "A tensor with all NaNs was produced in VAE.\n"
+ "Web UI will now convert VAE into 32-bit float and retry.\n"
+ "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
+ "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ )
+
+ devices.dtype_vae = torch.float32
+ model.first_stage_model.to(devices.dtype_vae)
+ batch = batch.to(devices.dtype_vae)
+
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if target_device is not None:
+ sample = sample.to(target_device)
+
+ samples.append(sample)
+
+ return samples
+
+
def decode_first_stage(model, x):
x = model.decode_first_stage(x.to(devices.dtype_vae))
@@ -587,7 +621,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+ "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -758,10 +792,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
-
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -1029,7 +1060,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
image = sd_samplers.sample_to_image(image, index, approximation=0)
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):