aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_vae_taesd.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_vae_taesd.py')
-rw-r--r--modules/sd_vae_taesd.py18
1 files changed, 15 insertions, 3 deletions
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 927a7298..d23812ef 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -61,16 +61,28 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
-def decode():
+def download_model(model_path):
+ model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading TAESD decoder to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
+def model():
global sd_vae_taesd
if sd_vae_taesd is None:
- model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
+ download_model(model_path)
+
if os.path.exists(model_path):
sd_vae_taesd = TAESD(model_path)
sd_vae_taesd.eval()
sd_vae_taesd.to(devices.device, devices.dtype)
else:
- raise FileNotFoundError('Tiny AE model not found')
+ raise FileNotFoundError('TAESD model not found')
return sd_vae_taesd.decoder