aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_vae_taesd.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r--modules/sd_vae_taesd.py26
1 files changed, 13 insertions, 13 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 5e8496e8..5bf7c76e 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -8,9 +8,9 @@ import os
import torch
import torch.nn as nn
-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared
-sd_vae_taesd = None
+sd_vae_taesd_models = {}
def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
-def download_model(model_path):
- model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
-
+def download_model(model_path, model_url):
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
@@ -72,17 +70,19 @@ def download_model(model_path):
def model():
- global sd_vae_taesd
+ model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+ loaded_model = sd_vae_taesd_models.get(model_name)
- if sd_vae_taesd is None:
- model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
- download_model(model_path)
+ if loaded_model is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+ download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
if os.path.exists(model_path):
- sd_vae_taesd = TAESD(model_path)
- sd_vae_taesd.eval()
- sd_vae_taesd.to(devices.device, devices.dtype)
+ loaded_model = TAESD(model_path)
+ loaded_model.eval()
+ loaded_model.to(devices.device, devices.dtype)
+ sd_vae_taesd_models[model_name] = loaded_model
else:
raise FileNotFoundError('TAESD model not found')
- return sd_vae_taesd.decoder
+ return loaded_model.decoder