aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_vae.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r--modules/sd_vae.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index da1bf15c..4ce238b8 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
return None, None
+def load_vae_dict(filename, map_location):
+ vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ return vae_dict_1
+
+
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file
# save_settings = False
@@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
- vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled: