aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/sd_vae.py31
-rw-r--r--modules/shared.py1
2 files changed, 26 insertions, 6 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 3856418e..ac71d62d 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,5 +1,6 @@
import torch
import os
+import collections
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
@@ -30,6 +31,7 @@ base_vae = None
loaded_vae_file = None
checkpoint_info = None
+checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
+ cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
+
if vae_file:
- assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
- print(f"Loading VAE weights from: {vae_file}")
- store_base_vae(model)
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- _load_vae_dict(model, vae_dict_1)
+ if cache_enabled and vae_file in checkpoints_loaded:
+ # use vae checkpoint cache
+ print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+ store_base_vae(model)
+ _load_vae_dict(model, checkpoints_loaded[vae_file])
+ else:
+ assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
+ print(f"Loading VAE weights from: {vae_file}")
+ store_base_vae(model)
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ _load_vae_dict(model, vae_dict_1)
+
+ if cache_enabled:
+ # cache newly loaded vae
+ checkpoints_loaded[vae_file] = vae_dict_1.copy()
+
+ # clean up cache if limit is reached
+ if cache_enabled:
+ while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
+ checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it
# It will be removed on refresh though
diff --git a/modules/shared.py b/modules/shared.py
index d4ddeea0..671d30e1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),