aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorSakura-Luna <53183413+Sakura-Luna@users.noreply.github.com>2023-05-14 12:42:44 +0800
committerSakura-Luna <53183413+Sakura-Luna@users.noreply.github.com>2023-05-14 14:06:01 +0800
commite14b586d0494d6c5cc3cbc45b5fa00c03d052443 (patch)
tree807b3e771ef465654b672956d09d94af525d14ab /modules
parentb08500cec8a791ef20082628b49b17df833f5dda (diff)
Add Tiny AE live preview
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_samplers_common.py21
-rw-r--r--modules/sd_vae_taesd.py76
-rw-r--r--modules/shared.py2
3 files changed, 90 insertions, 9 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bc074238..d3dc130c 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
- if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
- elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ if approximation == 1:
+ x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample)
+ x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1)
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ if approximation == 3:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 2:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
new file mode 100644
index 00000000..ccc97959
--- /dev/null
+++ b/modules/sd_vae_taesd.py
@@ -0,0 +1,76 @@
+"""
+Tiny AutoEncoder for Stable Diffusion
+(DNN for encoding / decoding SD's latent space)
+
+https://github.com/madebyollin/taesd
+"""
+import os
+import torch
+import torch.nn as nn
+
+from modules import devices, paths_internal
+
+sd_vae_taesd = None
+
+
+def conv(n_in, n_out, **kwargs):
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class Clamp(nn.Module):
+ @staticmethod
+ def forward(x):
+ return torch.tanh(x / 3) * 3
+
+
+class Block(nn.Module):
+ def __init__(self, n_in, n_out):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.fuse = nn.ReLU()
+
+ def forward(self, x):
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+def decoder():
+ return nn.Sequential(
+ Clamp(), conv(4, 64), nn.ReLU(),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), conv(64, 3),
+ )
+
+
+class TAESD(nn.Module):
+ latent_magnitude = 2
+ latent_shift = 0.5
+
+ def __init__(self, decoder_path="taesd_decoder.pth"):
+ """Initialize pretrained TAESD on the given device from the given checkpoints."""
+ super().__init__()
+ self.decoder = decoder()
+ self.decoder.load_state_dict(
+ torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+
+ @staticmethod
+ def unscale_latents(x):
+ """[0, 1] -> raw latents"""
+ return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+
+def decode():
+ global sd_vae_taesd
+
+ if sd_vae_taesd is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
+ if os.path.exists(model_path):
+ sd_vae_taesd = TAESD(model_path)
+ sd_vae_taesd.eval()
+ sd_vae_taesd.to(devices.device, devices.dtype)
+ else:
+ raise FileNotFoundError('Tiny AE mdoel not found')
+
+ return sd_vae_taesd.decoder
diff --git a/modules/shared.py b/modules/shared.py
index 4631965b..6760a900 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
}))