aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_vae_approx.py
diff options
context:
space:
mode:
authorLeon Feng <523684+leon0707@users.noreply.github.com>2023-07-18 04:24:14 -0400
committerGitHub <noreply@github.com>2023-07-18 04:24:14 -0400
commita3730bd9becd2f1f5d209885b694b0dec178d110 (patch)
tree8ac9948d89606f7519df786f07f6ddb93c3d2720 /modules/sd_vae_approx.py
parentd6668347c8b85b11b696ac56777cc396e34ee1f9 (diff)
parent871b8687a82bb2ca907d8a49c87aed7635b8fc33 (diff)
Merge branch 'dev' into fix-11805
Diffstat (limited to 'modules/sd_vae_approx.py')
-rw-r--r--modules/sd_vae_approx.py59
1 files changed, 42 insertions, 17 deletions
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index e2f00468..86bd658a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -2,9 +2,9 @@ import os
import torch
from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared
-sd_vae_approx_model = None
+sd_vae_approx_models = {}
class VAEApprox(nn.Module):
@@ -31,30 +31,55 @@ class VAEApprox(nn.Module):
return x
+def download_model(model_path, model_url):
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading VAEApprox model to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
def model():
- global sd_vae_approx_model
+ model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+ loaded_model = sd_vae_approx_models.get(model_name)
- if sd_vae_approx_model is None:
- model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
- sd_vae_approx_model = VAEApprox()
+ if loaded_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
if not os.path.exists(model_path):
- model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
- sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
- sd_vae_approx_model.eval()
- sd_vae_approx_model.to(devices.device, devices.dtype)
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
+
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+ download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+ loaded_model = VAEApprox()
+ loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+ loaded_model.eval()
+ loaded_model.to(devices.device, devices.dtype)
+ sd_vae_approx_models[model_name] = loaded_model
- return sd_vae_approx_model
+ return loaded_model
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- coefs = torch.tensor([
- [0.298, 0.207, 0.208],
- [0.187, 0.286, 0.173],
- [-0.158, 0.189, 0.264],
- [-0.184, -0.271, -0.473],
- ]).to(sample.device)
+ if shared.sd_model.is_sdxl:
+ coeffs = [
+ [ 0.3448, 0.4168, 0.4395],
+ [-0.1953, -0.0290, 0.0250],
+ [ 0.1074, 0.0886, -0.0163],
+ [-0.3730, -0.2499, -0.2088],
+ ]
+ else:
+ coeffs = [
+ [ 0.298, 0.207, 0.208],
+ [ 0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]
+
+ coefs = torch.tensor(coeffs).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)