aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordrhead <1313496+drhead@users.noreply.github.com>2023-11-29 17:38:53 -0500
committerGitHub <noreply@github.com>2023-11-29 17:38:53 -0500
commitb25c126ccdbc4da22ade46597a9addf808998989 (patch)
tree0e04d05b3b57371da64f5abe9b53ce646f7520f0
parentf0f100e67b78f686dc73cf3c8cad422e45cc9b8a (diff)
Protect alphas_cumprod from downcasting
-rw-r--r--modules/sd_models.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 841402e8..de80a493 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -387,7 +387,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
+ alphas_cumprod = model.alphas_cumprod
+ model.alphas_cumprod = None
model.half()
+ model.alphas_cumprod = alphas_cumprod
+ model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
@@ -642,6 +646,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
+ 'alphas_cumprod': None,
'': torch.float16,
}