author    Zac Liu <liuguang@baai.ac.cn>            2022-12-06 09:16:15 +0800
committer GitHub <noreply@github.com>              2022-12-06 09:16:15 +0800
commit    3ebf977a6e4f478ab918e44506974beee32da276 (patch)
tree      f68456207e5cd78718ec1e9c588ecdc22d568d81 /extensions-builtin/LDSR
parent    231fb72872191ffa8c446af1577c9003b3d19d4f (diff)
parent    44c46f0ed395967cd3830dd481a2db759fda5b3b (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'extensions-builtin/LDSR')
-rw-r--r--  extensions-builtin/LDSR/ldsr_model_arch.py         230
-rw-r--r--  extensions-builtin/LDSR/preload.py                   6
-rw-r--r--  extensions-builtin/LDSR/scripts/ldsr_model.py       64
-rw-r--r--  extensions-builtin/LDSR/sd_hijack_autoencoder.py   286
4 files changed, 586 insertions, 0 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
new file mode 100644
index 00000000..90e0a2f0
--- /dev/null
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -0,0 +1,230 @@
+import gc
+import time
+import warnings
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from einops import rearrange, repeat
+from omegaconf import OmegaConf
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import instantiate_from_config, ismap
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+# LDSR: wrapper around the Latent Diffusion Super Resolution (4x) model
+class LDSR:
+ def load_model_from_config(self, half_attention):
+ print(f"Loading model from {self.modelPath}")
+ pl_sd = torch.load(self.modelPath, map_location="cpu")
+ sd = pl_sd["state_dict"]
+ config = OmegaConf.load(self.yamlPath)
+ model = instantiate_from_config(config.model)
+ model.load_state_dict(sd, strict=False)
+ model.cuda()
+ if half_attention:
+ model = model.half()
+
+ model.eval()
+ return {"model": model}
+
+ def __init__(self, model_path, yaml_path):
+ self.modelPath = model_path
+ self.yamlPath = yaml_path
+
+ @staticmethod
+ def run(model, selected_path, custom_steps, eta):
+ example = get_cond(selected_path)
+
+        n_runs = 1
+        guider = None
+        ckwargs = None
+        ddim_use_x0_pred = False
+        temperature = 1.
+        custom_shape = None
+
+ height, width = example["image"].shape[1:3]
+ split_input = height >= 128 and width >= 128
+
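+        # Inputs of at least 128x128 are processed in overlapping 128-pixel patches
+        # (stride 64) via split_input_params, which keeps VRAM usage bounded on large images.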
+ if split_input:
+ ks = 128
+ stride = 64
+            vqf = 4  # first-stage (VQ) downsampling factor
+ model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
+ "vqf": vqf,
+ "patch_distributed_vq": True,
+ "tie_braker": False,
+ "clip_max_weight": 0.5,
+ "clip_min_weight": 0.01,
+ "clip_max_tie_weight": 0.5,
+ "clip_min_tie_weight": 0.01}
+ else:
+ if hasattr(model, "split_input_params"):
+ delattr(model, "split_input_params")
+
+ x_t = None
+ logs = None
+ for n in range(n_runs):
+ if custom_shape is not None:
+ x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
+ x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
+
+ logs = make_convolutional_sample(example, model,
+ custom_steps=custom_steps,
+ eta=eta, quantize_x0=False,
+ custom_shape=custom_shape,
+ temperature=temperature, noise_dropout=0.,
+ corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
+ ddim_use_x0_pred=ddim_use_x0_pred
+ )
+ return logs
+
+ def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
+ model = self.load_model_from_config(half_attention)
+
+ # Run settings
+ diffusion_steps = int(steps)
+ eta = 1.0
+
+ down_sample_method = 'Lanczos'
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ im_og = image
+ width_og, height_og = im_og.size
+        # LDSR always upscales by 4x; if the maximum upscale size ever becomes configurable, the hard-coded 4 below should become that variable
+ down_sample_rate = target_scale / 4
+ wd = width_og * down_sample_rate
+ hd = height_og * down_sample_rate
+ width_downsampled_pre = int(np.ceil(wd))
+ height_downsampled_pre = int(np.ceil(hd))
+
+ if down_sample_rate != 1:
+ print(
+ f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
+ im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
+ else:
+            print(f"Downsample rate is 1 from {target_scale} / 4 (not downsampling)")
+
+ # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
+ pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
+ im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+
+ logs = self.run(model["model"], im_padded, diffusion_steps, eta)
+
+ sample = logs["sample"]
+ sample = sample.detach().cpu()
+ sample = torch.clamp(sample, -1., 1.)
+ sample = (sample + 1.) / 2. * 255
+ sample = sample.numpy().astype(np.uint8)
+ sample = np.transpose(sample, (0, 2, 3, 1))
+ a = Image.fromarray(sample[0])
+
+ # remove padding
+ a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
+
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ return a
+
+
+def get_cond(selected_path):
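+    # Build the conditioning dict the LDSR model expects: "LR_image" is the input scaled
+    # to [-1, 1]; "image" is a 4x antialiased resize of it that defines the output size.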
+ example = dict()
+ up_f = 4
+ c = selected_path.convert('RGB')
+ c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
+ c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
+ antialias=True)
+ c_up = rearrange(c_up, '1 c h w -> 1 h w c')
+ c = rearrange(c, '1 c h w -> 1 h w c')
+ c = 2. * c - 1.
+
+ c = c.to(torch.device("cuda"))
+ example["LR_image"] = c
+ example["image"] = c_up
+
+ return example
+
+
+@torch.no_grad()
+def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
+ mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
+ corrector_kwargs=None, x_t=None
+ ):
+ ddim = DDIMSampler(model)
+ bs = shape[0]
+ shape = shape[1:]
+ print(f"Sampling with eta = {eta}; steps: {steps}")
+ samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
+ normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
+ mask=mask, x0=x0, temperature=temperature, verbose=False,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs, x_t=x_t)
+
+ return samples, intermediates
+
+
+@torch.no_grad()
+def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
+ corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
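+    # Runs DDIM sampling conditioned on the batch and decodes the result through the
+    # first stage; returns a log dict with "input", "reconstruction", "sample" and timing.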
+ log = dict()
+
+ z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=not (hasattr(model, 'split_input_params')
+ and model.cond_stage_key == 'coordinates_bbox'),
+ return_original_cond=True)
+
+ if custom_shape is not None:
+ z = torch.randn(custom_shape)
+ print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
+
+ z0 = None
+
+ log["input"] = x
+ log["reconstruction"] = xrec
+
+ if ismap(xc):
+ log["original_conditioning"] = model.to_rgb(xc)
+ if hasattr(model, 'cond_stage_key'):
+ log[model.cond_stage_key] = model.to_rgb(xc)
+
+ else:
+ log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
+ if model.cond_stage_model:
+ log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
+ if model.cond_stage_key == 'class_label':
+ log[model.cond_stage_key] = xc[model.cond_stage_key]
+
+ with model.ema_scope("Plotting"):
+ t0 = time.time()
+
+ sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
+ eta=eta,
+ quantize_x0=quantize_x0, mask=None, x0=z0,
+ temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
+ x_t=x_T)
+ t1 = time.time()
+
+ if ddim_use_x0_pred:
+ sample = intermediates['pred_x0'][-1]
+
+ x_sample = model.decode_first_stage(sample)
+
+ try:
+ x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
+ log["sample_noquant"] = x_sample_noquant
+ log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
+    except Exception:
+ pass
+
+ log["sample"] = x_sample
+ log["time"] = t1 - t0
+
+ return log
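+
+
+# Usage sketch (illustrative only; the checkpoint/config paths below are assumptions,
+# not shipped defaults):
+#
+#   from PIL import Image
+#   ldsr = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")
+#   upscaled = ldsr.super_resolution(Image.open("input.png"), steps=100, target_scale=2)
+#   upscaled.save("upscaled.png")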
diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py
new file mode 100644
index 00000000..d746007c
--- /dev/null
+++ b/extensions-builtin/LDSR/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
new file mode 100644
index 00000000..1cef29a4
--- /dev/null
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -0,0 +1,64 @@
+import os
+import sys
+import traceback
+
+from basicsr.utils.download_util import load_file_from_url
+
+from modules.upscaler import Upscaler, UpscalerData
+from ldsr_model_arch import LDSR
+from modules import shared, script_callbacks
+import sd_hijack_autoencoder  # imported for its side effect: restores VQModel/VQModelInterface into ldm.models.autoencoder
+
+
+class UpscalerLDSR(Upscaler):
+ def __init__(self, user_path):
+ self.name = "LDSR"
+ self.user_path = user_path
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
+ super().__init__()
+ scaler_data = UpscalerData("LDSR", None, self)
+ self.scalers = [scaler_data]
+
+ def load_model(self, path: str):
+ # Remove incorrect project.yaml file if too big
+ yaml_path = os.path.join(self.model_path, "project.yaml")
+ old_model_path = os.path.join(self.model_path, "model.pth")
+ new_model_path = os.path.join(self.model_path, "model.ckpt")
+ if os.path.exists(yaml_path):
+ statinfo = os.stat(yaml_path)
+ if statinfo.st_size >= 10485760:
+ print("Removing invalid LDSR YAML file.")
+ os.remove(yaml_path)
+ if os.path.exists(old_model_path):
+ print("Renaming model from model.pth to model.ckpt")
+ os.rename(old_model_path, new_model_path)
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="model.ckpt", progress=True)
+ yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
+ file_name="project.yaml", progress=True)
+
+ try:
+ return LDSR(model, yaml)
+
+ except Exception:
+ print("Error importing LDSR:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ def do_upscale(self, img, path):
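+        # The LDSR checkpoint is loaded fresh for every upscale call and released again
+        # inside super_resolution once sampling finishes.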
+ ldsr = self.load_model(path)
+ if ldsr is None:
+            print("LDSR model could not be loaded; returning image unchanged.")
+ return img
+ ddim_steps = shared.opts.ldsr_steps
+ return ldsr.super_resolution(img, ddim_steps, self.scale)
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
new file mode 100644
index 00000000..8e03c7f8
--- /dev/null
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -0,0 +1,286 @@
+# The content of this file comes from ldm/models/autoencoder.py in the compvis/stable-diffusion repo.
+# VQModel & VQModelInterface were removed from that file when the webui moved to the stability-ai/stablediffusion repo.
+# As the LDSR upscaler relies on VQModel & VQModelInterface, this hijack puts them back into ldm.models.autoencoder.
+
+import numpy as np
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+from packaging import version
+from torch.optim.lr_scheduler import LambdaLR
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.ema import LitEma
+from ldm.util import instantiate_from_config
+
+import ldm.models.autoencoder
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ batch_resize_range=None,
+ scheduler_config=None,
+ lr_g_factor=1.0,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ use_ema=False
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.n_embed = n_embed
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap,
+ sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ self.batch_resize_range = batch_resize_range
+ if self.batch_resize_range is not None:
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.scheduler_config = scheduler_config
+ self.lr_g_factor = lr_g_factor
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+                    print(f"Deleting key {k} from state_dict.")
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ print(f"Unexpected Keys: {unexpected}")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input, return_pred_indices=False):
+ quant, diff, (_,_,ind) = self.encode(input)
+ dec = self.decode(quant)
+ if return_pred_indices:
+ return dec, diff, ind
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ if self.batch_resize_range is not None:
+ lower_size = self.batch_resize_range[0]
+ upper_size = self.batch_resize_range[1]
+ if self.global_step <= 4:
+ # do the first few batches with max size to avoid later oom
+ new_resize = upper_size
+ else:
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+ if new_resize != x.shape[2]:
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
+ x = x.detach()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ # https://github.com/pytorch/pytorch/issues/37142
+ # try not to fool the heuristics
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train",
+ predicted_indices=ind)
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+            self._validation_step(batch, batch_idx, suffix="_ema")  # EMA metrics are logged via self.log inside _validation_step
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, suffix=""):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log(f"val{suffix}/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log(f"val{suffix}/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
+ del log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr_d = self.learning_rate
+ lr_g = self.lr_g_factor*self.learning_rate
+ print("lr_d", lr_d)
+ print("lr_g", lr_g)
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr_g, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr_d, betas=(0.5, 0.9))
+
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ {
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ ]
+ return [opt_ae, opt_disc], scheduler
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if only_inputs:
+ log["inputs"] = x
+ return log
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ if plot_ema:
+ with self.ema_scope():
+ xrec_ema, _ = self(x)
+                if x.shape[1] > 3:
+                    xrec_ema = self.to_rgb(xrec_ema)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQModelInterface(VQModel):
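+    # VQModel variant used as a latent-diffusion first stage: encode() stops before
+    # quantization, and decode() quantizes on demand unless force_not_quantize is set.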
+ def __init__(self, embed_dim, *args, **kwargs):
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ self.embed_dim = embed_dim
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, h, force_not_quantize=False):
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+setattr(ldm.models.autoencoder, "VQModel", VQModel)
+setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
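+
+# Rough check that the hijack has taken effect (a sketch, assuming this module is on
+# sys.path as it is when the extension loads):
+#
+#   import ldm.models.autoencoder
+#   import sd_hijack_autoencoder  # side effect: re-attaches the two classes
+#   assert hasattr(ldm.models.autoencoder, "VQModelInterface")
+#
+# After that, configs whose first_stage_config targets ldm.models.autoencoder.VQModelInterface
+# resolve again through ldm.util.instantiate_from_config.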