aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-01 14:45:12 +0300
committerGitHub <noreply@github.com>2024-01-01 14:45:12 +0300
commit267fd5d76b00b0c22edffa83c1a078680ba8b42f (patch)
treec4092b8ec7430f15aaac7d9f8a0fa2199de28140 /modules/processing.py
parentd613cd17c72c753bd1e314dff74dc22d9a949374 (diff)
parent5381405eaa1e809e5cfb97522bd4c19d3c946079 (diff)
Merge pull request #14145 from drhead/zero-terminal-snr
Implement zero terminal SNR noise schedule option
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/modules/processing.py b/modules/processing.py
index b30df60d..846e4796 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -898,6 +898,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)