aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_vae.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-19 10:39:51 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-19 10:39:51 +0300
commit0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 (patch)
tree0e81a16c42f716c704d6aa63458f7c3c1894c56e /modules/sd_vae.py
parentc7e50425f63c07242068f8dcccce70a4ef28a17f (diff)
allow baking in VAE in checkpoint merger tab
do not save config if it's the default for checkpoint merger tab change file naming scheme for checkpoint merger tab allow just saving A without any merging for checkpoint merger tab some stylistic changes for UI in checkpoint merger tab
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r--modules/sd_vae.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index da1bf15c..4ce238b8 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
return None, None
+def load_vae_dict(filename, map_location):
+ vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ return vae_dict_1
+
+
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file
# save_settings = False
@@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
- vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled: