Diffstat (limited to 'modules')
-rw-r--r--  modules/sd_samplers.py   | 29
-rw-r--r--  modules/sd_vae_approx.py | 58
-rw-r--r--  modules/shared.py        |  6
3 files changed, 77 insertions, 16 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 27ef4ff8..177b5338 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -9,7 +9,7 @@ import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images
+from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -106,28 +106,31 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def single_sample_to_image(sample, approximation=False):
- if approximation:
- # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- coefs = torch.tensor(
- [[ 0.298, 0.207, 0.208],
- [ 0.187, 0.286, 0.173],
- [-0.158, 0.189, 0.264],
- [-0.184, -0.271, -0.473]]).to(sample.device)
- x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
-def sample_to_image(samples, index=0, approximation=False):
+def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
-def samples_to_image_grid(samples, approximation=False):
+def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
@@ -136,7 +139,7 @@ def store_latent(decoded):
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
- shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate)
+ shared.state.current_image = sample_to_image(decoded)
class InterruptedException(BaseException):
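
With this change, preview images during sampling can be decoded in one of three ways, selected by the new "Image creation progress mode" setting: "Full" (index 0) runs the real first-stage VAE via processing.decode_first_stage, "Approx NN" (index 1) runs the small convolutional VAEApprox network from the new sd_vae_approx module, and "Approx cheap" (index 2) uses a fixed linear projection of the latent channels. A minimal usage sketch, not part of the patch, assuming a running webui with a model loaded and a hypothetical 64x64 latent:

    # hypothetical example of the new dispatch in single_sample_to_image
    import torch
    from modules import sd_samplers

    latent = torch.randn(4, 64, 64)  # assumed latent shape for a 512x512 image
    img_full  = sd_samplers.single_sample_to_image(latent, approximation=0)  # real VAE decode
    img_nn    = sd_samplers.single_sample_to_image(latent, approximation=1)  # VAEApprox network
    img_cheap = sd_samplers.single_sample_to_image(latent, approximation=2)  # linear projection
    # approximation=None (the default) looks the index up from opts.show_progress_type
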
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
new file mode 100644
index 00000000..0a58542d
--- /dev/null
+++ b/modules/sd_vae_approx.py
@@ -0,0 +1,58 @@
+import os
+
+import torch
+from torch import nn
+from modules import devices, paths
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+ def __init__(self):
+ super(VAEApprox, self).__init__()
+ self.conv1 = nn.Conv2d(4, 8, (7, 7))
+ self.conv2 = nn.Conv2d(8, 16, (5, 5))
+ self.conv3 = nn.Conv2d(16, 32, (3, 3))
+ self.conv4 = nn.Conv2d(32, 64, (3, 3))
+ self.conv5 = nn.Conv2d(64, 32, (3, 3))
+ self.conv6 = nn.Conv2d(32, 16, (3, 3))
+ self.conv7 = nn.Conv2d(16, 8, (3, 3))
+ self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+ def forward(self, x):
+ extra = 11
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ x = nn.functional.pad(x, (extra, extra, extra, extra))
+
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+ x = layer(x)
+ x = nn.functional.leaky_relu(x, 0.1)
+
+ return x
+
+
+def model():
+ global sd_vae_approx_model
+
+ if sd_vae_approx_model is None:
+ sd_vae_approx_model = VAEApprox()
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.eval()
+ sd_vae_approx_model.to(devices.device, devices.dtype)
+
+ return sd_vae_approx_model
+
+
+def cheap_approximation(sample):
+ # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
+
+ coefs = torch.tensor([
+ [0.298, 0.207, 0.208],
+ [0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]).to(sample.device)
+
+ x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+
+ return x_sample
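
In the VAEApprox forward pass the arithmetic works out so that the preview is exactly twice the latent resolution: interpolation doubles height and width, the padding adds 11 pixels per side (22 per dimension), and the eight unpadded convolutions (7x7, 5x5, then six 3x3 kernels) remove 6 + 4 + 6*2 = 22 pixels again, so a 4xHxW latent comes out as a 3x2Hx2W RGB tensor. A sketch of calling both decoders directly, not part of the patch, assuming the webui context and that the weights exist at models/VAE-approx/model.pt:

    # hypothetical example of the two approximations
    import torch
    from modules import sd_vae_approx, devices

    latent = torch.randn(4, 64, 64)                        # assumed 64x64 latent
    rgb_cheap = sd_vae_approx.cheap_approximation(latent)  # (3, 64, 64), one einsum
    approx_net = sd_vae_approx.model()                     # lazily loads and caches VAEApprox
    with torch.no_grad():
        rgb_nn = approx_net(latent.to(devices.device, devices.dtype).unsqueeze(0))[0]  # (3, 128, 128)
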
diff --git a/modules/shared.py b/modules/shared.py
index eb3e5aec..3cc3c724 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -212,9 +212,9 @@ class State:
import modules.sd_samplers
if opts.show_progress_grid:
- self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent, approximation=opts.show_progress_approximate)
+ self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
else:
- self.current_image = modules.sd_samplers.sample_to_image(self.current_latent, approximation=opts.show_progress_approximate)
+ self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
@@ -392,7 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_approximate": OptionInfo(False, "Calculate small previews using fast linear approximation instead of VAE"),
+ "show_progress_type": OptionInfo("Full", "Image creation progress mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),