aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_vae_approx.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_vae_approx.py')
-rw-r--r--modules/sd_vae_approx.py22
1 files changed, 16 insertions, 6 deletions
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index b348f3ae..86bd658a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -64,12 +64,22 @@ def model():
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- coefs = torch.tensor([
- [0.298, 0.207, 0.208],
- [0.187, 0.286, 0.173],
- [-0.158, 0.189, 0.264],
- [-0.184, -0.271, -0.473],
- ]).to(sample.device)
+ if shared.sd_model.is_sdxl:
+ coeffs = [
+ [ 0.3448, 0.4168, 0.4395],
+ [-0.1953, -0.0290, 0.0250],
+ [ 0.1074, 0.0886, -0.0163],
+ [-0.3730, -0.2499, -0.2088],
+ ]
+ else:
+ coeffs = [
+ [ 0.298, 0.207, 0.208],
+ [ 0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]
+
+ coefs = torch.tensor(coeffs).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)