aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/processing.py16
-rw-r--r--modules/sd_samplers_common.py44
-rw-r--r--modules/sd_vae_approx.py2
-rw-r--r--modules/sd_vae_taesd.py52
-rw-r--r--modules/shared.py2
5 files changed, 89 insertions, 27 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 8086a2b0..aa6d4d2a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -16,6 +16,7 @@ from typing import Any, Dict, List
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
+from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
@@ -30,7 +31,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
-decode_first_stage = sd_samplers_common.decode_first_stage
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+ image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
@@ -203,7 +203,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -216,7 +216,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
+ conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
return conditioning_image
@@ -1099,9 +1099,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
@@ -1339,10 +1338,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
- image = 2. * image - 1.
- image = image.to(shared.device, dtype=devices.dtype_vae)
-
- self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
devices.torch_gc()
if self.resize_mode == 3:
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index b3d344e7..42a29fc9 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
-def single_sample_to_image(sample, approximation=None):
+def samples_to_images_tensor(sample, approximation=None, model=None):
+ '''latents -> images [-1, 1]'''
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
+ x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
x_sample = sample * 1.5
- x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+ x_sample = x_sample * 2 - 1
else:
- x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+ if model is None:
+ model = shared.sd_model
+ x_sample = model.decode_first_stage(sample)
+
+ return x_sample
+
+
+def single_sample_to_image(sample, approximation=None):
+ x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -45,9 +55,9 @@ def single_sample_to_image(sample, approximation=None):
def decode_first_stage(model, x):
- x = model.decode_first_stage(x.to(devices.dtype_vae))
-
- return x
+ x = x.to(devices.dtype_vae)
+ approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
+ return samples_to_images_tensor(x, approx_index, model)
def sample_to_image(samples, index=0, approximation=None):
@@ -58,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+def images_tensor_to_samples(image, approximation=None, model=None):
+ '''image[0, 1] -> latent'''
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
+
+ if approximation == 3:
+ image = image.to(devices.device, devices.dtype)
+ x_latent = sd_vae_taesd.encoder_model()(image)
+ else:
+ if model is None:
+ model = shared.sd_model
+ image = image.to(shared.device, dtype=devices.dtype_vae)
+ image = image * 2 - 1
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+
+ return x_latent
+
+
def store_latent(decoded):
state.current_latent = decoded
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 86bd658a..3965e223 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -81,6 +81,6 @@ def cheap_approximation(sample):
coefs = torch.tensor(coeffs).to(sample.device)
- x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+ x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
return x_sample
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 5bf7c76e..808eb362 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -44,7 +44,17 @@ def decoder():
)
-class TAESD(nn.Module):
+def encoder():
+ return nn.Sequential(
+ conv(3, 64), Block(64, 64),
+ conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+ conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+ conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+ conv(64, 4),
+ )
+
+
+class TAESDDecoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
@@ -55,21 +65,28 @@ class TAESD(nn.Module):
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
- @staticmethod
- def unscale_latents(x):
- """[0, 1] -> raw latents"""
- return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+class TAESDEncoder(nn.Module):
+ latent_magnitude = 3
+ latent_shift = 0.5
+
+ def __init__(self, encoder_path="taesd_encoder.pth"):
+ """Initialize pretrained TAESD on the given device from the given checkpoints."""
+ super().__init__()
+ self.encoder = encoder()
+ self.encoder.load_state_dict(
+ torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
def download_model(model_path, model_url):
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
- print(f'Downloading TAESD decoder to: {model_path}')
+ print(f'Downloading TAESD model to: {model_path}')
torch.hub.download_url_to_file(model_url, model_path)
-def model():
+def decoder_model():
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)
@@ -78,7 +95,7 @@ def model():
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
if os.path.exists(model_path):
- loaded_model = TAESD(model_path)
+ loaded_model = TAESDDecoder(model_path)
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
sd_vae_taesd_models[model_name] = loaded_model
@@ -86,3 +103,22 @@ def model():
raise FileNotFoundError('TAESD model not found')
return loaded_model.decoder
+
+
+def encoder_model():
+ model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+ loaded_model = sd_vae_taesd_models.get(model_name)
+
+ if loaded_model is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+ download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
+
+ if os.path.exists(model_path):
+ loaded_model = TAESDEncoder(model_path)
+ loaded_model.eval()
+ loaded_model.to(devices.device, devices.dtype)
+ sd_vae_taesd_models[model_name] = loaded_model
+ else:
+ raise FileNotFoundError('TAESD model not found')
+
+ return loaded_model.encoder
diff --git a/modules/shared.py b/modules/shared.py
index cec030f7..61ba9347 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -430,6 +430,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
+ "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img or inpaint mask)"),
+ "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {