From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 19:56:09 +0300 Subject: fix bug with "Ignore selected VAE for..." option completely disabling VAE election rework VAE resolving code to be more simple --- modules/sd_vae.py | 194 ++++++++++++++++++++++-------------------------------- 1 file changed, 77 insertions(+), 117 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..6ea92711 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] - - -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), + return os.path.basename(filepath) + + +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), ] - if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): - candidates.append(shared.cmd_opts.vae_path) + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + candidates = [] + for path in paths: + candidates += glob.iglob(path, recursive=True) + for filepath in candidates: name = get_filename(filepath) - res[name] = filepath - vae_list.clear() - vae_list.extend(default_vae_list) - vae_list.extend(list(res.keys())) - vae_dict.clear() - vae_dict.update(res) - vae_dict.update(default_vae_dict) - return vae_list - - -def get_vae_from_settings(vae_file="auto"): - # else, we load from settings, if not set to be default - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print(f"Selected VAE doesn't exist: {vae_file}") - return vae_file - - -def resolve_vae(checkpoint_file=None, vae_file="auto"): - global first_load, vae_dict, vae_list - - # if vae_file argument is provided, it takes priority, but not saved - if vae_file and vae_file not in default_vae_list: - if not os.path.isfile(vae_file): - print(f"VAE provided as function argument doesn't exist: {vae_file}") - vae_file = "auto" - # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported - if first_load and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - shared.opts.data['sd_vae'] = get_filename(vae_file) - else: - print(f"VAE provided as command line argument doesn't exist: {vae_file}") - # fallback to selector in settings, if vae selector not set to act as default fallback - if not shared.opts.sd_vae_as_default: - vae_file = get_vae_from_settings(vae_file) - # vae-path cmd arg takes priority for auto - if vae_file == "auto" and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - print(f"Using VAE provided as command line argument: {vae_file}") - # if still not found, try look for ".vae.pt" beside model - model_path = os.path.splitext(checkpoint_file)[0] - if vae_file == "auto": - vae_file_try = model_path + ".vae.pt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.ckpt" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.ckpt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.safetensors" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.safetensors" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # No more fallbacks for auto - if vae_file == "auto": - vae_file = None - # Last check, just because - if vae_file and not os.path.exists(vae_file): - vae_file = None - - return vae_file - - -def load_vae(model, vae_file=None): - global first_load, vae_dict, vae_list, loaded_vae_file + vae_dict[name] = filepath + + +def find_vae_near_checkpoint(checkpoint_file): + checkpoint_path = os.path.splitext(checkpoint_file)[0] + for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + if os.path.isfile(vae_location): + return vae_location + + return None + + +def resolve_vae(checkpoint_file): + if shared.cmd_opts.vae_path is not None: + return shared.cmd_opts.vae_path, 'from commandline argument' + + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"): + return vae_near_checkpoint, 'found near the checkpoint' + + if shared.opts.sd_vae == "None": + return None, None + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return vae_from_options, 'specified in settings' + + if shared.opts.sd_vae != "Automatic": + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return None, None + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -165,12 +120,12 @@ def load_vae(model, vae_file=None): if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache - print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) _load_vae_dict(model, checkpoints_loaded[vae_file]) else: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" + print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) @@ -191,14 +146,12 @@ def load_vae(model, vae_file=None): vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file - vae_list.append(vae_opt) + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file - first_load = False - # don't call this from outside def _load_vae_dict(model, vae_dict_1): @@ -211,7 +164,10 @@ def clear_loaded_vae(): loaded_vae_file = None -def reload_vae_weights(sd_model=None, vae_file="auto"): +unspecified = object() + + +def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack if not sd_model: @@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): checkpoint_info = sd_model.sd_checkpoint_info checkpoint_file = checkpoint_info.filename - vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if vae_file == unspecified: + vae_file, vae_source = resolve_vae(checkpoint_file) + else: + vae_source = "from function argument" if loaded_vae_file == vae_file: return @@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file) + load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) @@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("VAE Weights loaded.") + print("VAE weights loaded.") return sd_model -- cgit v1.2.1