diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-05 07:52:29 +0300 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-05 07:52:29 +0300 |
commit | 22ecb78b51f7e6f0234cbc0efbde4ee9a2dc466f (patch) | |
tree | 69e3bf4d53f4113f192116c252a2e410bb5b1f90 /modules/sd_vae.py | |
parent | 390bffa81b747a7eb38ac7a0cd6dfb9fcc388151 (diff) | |
parent | 0ae2767ae6bb775de448b0d8cda1806edb2aef67 (diff) |
Merge branch 'dev' into multiple_loaded_models
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r-- | modules/sd_vae.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e4ff2994..0bd5e19b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,6 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -16,6 +16,7 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: return base_vae @@ -50,6 +51,7 @@ def get_filename(filepath): def refresh_vae_list(): + global vae_dict vae_dict.clear() paths = [ @@ -83,6 +85,8 @@ def refresh_vae_list(): name = get_filename(filepath) vae_dict[name] = filepath + vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))) + def find_vae_near_checkpoint(checkpoint_file): checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] @@ -97,6 +101,16 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return None, None + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return vae_from_metadata, "from user metadata" + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) |