aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorcatboxanon <122327233+catboxanon@users.noreply.github.com>2023-11-29 18:33:32 -0500
committercatboxanon <122327233+catboxanon@users.noreply.github.com>2023-11-29 18:33:32 -0500
commitde79597ab9894965e3702939b8536ec3dcc3c859 (patch)
tree3572b8c917b2f4ea4fb232647f470c0f8b43723c /modules
parentffa7f8201d849636bb327b3b40298e7c169ff204 (diff)
Only apply ztSNR related code if alphas_cumprod exists
Diffstat (limited to 'modules')
-rw-r--r--modules/processing.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/modules/processing.py b/modules/processing.py
index f3883d5b..7e73d7e2 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -882,15 +882,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
alphas_bar[-1] = 4.8973451890853435e-08
return alphas_bar
- p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
-
- if opts.use_downcasted_alpha_bar:
- p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
- p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
- if opts.sd_noise_schedule == "Zero Terminal SNR":
- p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
- print("rescaling noise schedule for zero snr")
- p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ print("rescaling noise schedule for zero snr")
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)