aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/LDSR/sd_hijack_autoencoder.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-10 08:25:25 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-10 08:25:25 +0300
commit96d6ca4199e7c5eee8d451618de5161cea317c40 (patch)
tree8f101a345bcd1d66f4047b5e20918e2058e4dc7c /extensions-builtin/LDSR/sd_hijack_autoencoder.py
parent762265eab58cdb8f2d6398769bab43d8b8db0075 (diff)
manual fixes for ruff
Diffstat (limited to 'extensions-builtin/LDSR/sd_hijack_autoencoder.py')
-rw-r--r--extensions-builtin/LDSR/sd_hijack_autoencoder.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index db2231dd..6303fed5 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -1,16 +1,21 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
-
+import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
+from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ if x.shape[1] > 3:
+ xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log