about | summary | refs | log | tree | commit | diff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
author: AUTOMATIC1111 <16777216c@gmail.com> 2024-03-02 06:54:11 +0300
committer: AUTOMATIC1111 <16777216c@gmail.com> 2024-03-02 06:55:03 +0300
commit: 141a17e9693065c33a2b1d30f04a0083bb687775 (patch)
tree: a0ef7f513a7f8d5fadcb126135b9b565947ab8af /modules/sd_models.py
parent: da67afe5f68497a04d1fd9173bbd256b73d9d251 (diff)
style changes for #14979
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- modules/sd_models.py 70
1 files changed, 41 insertions, 29 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index db72e120..747fc39e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -552,36 +552,48 @@ def repair_config(sd_config):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
+
+def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt back to alphas_bar
+ alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+
def apply_alpha_schedule_override(sd_model, p=None):
- def rescale_zero_terminal_snr_abar(alphas_cumprod):
- alphas_bar_sqrt = alphas_cumprod.sqrt()
-
- # Store old values.
- alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
- alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-
- # Shift so the last timestep is zero.
- alphas_bar_sqrt -= (alphas_bar_sqrt_T)
-
- # Scale so the first timestep is back to the old value.
- alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-
- # Convert alphas_bar_sqrt to betas
- alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
- alphas_bar[-1] = 4.8973451890853435e-08
- return alphas_bar
-
- if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'):
- sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
-
- if opts.use_downcasted_alpha_bar:
- if p is not None:
- p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
- sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
- if opts.sd_noise_schedule == "Zero Terminal SNR":
- if p is not None:
- p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
- sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
+ """
+ Applies an override to the alpha schedule of the model according to settings.
+ - downcasts the alpha schedule to half precision
+ - rescales the alpha schedule to have zero terminal SNR
+ """
+
+ if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
+ return
+
+ sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ if p is not None:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
+
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ if p is not None:
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
+
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'