aboutsummaryrefslogtreecommitdiff
path: root/modules/codeformer/vqgan_arch.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/codeformer/vqgan_arch.py')
-rw-r--r--modules/codeformer/vqgan_arch.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index 820e6b12..b24a0394 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -326,7 +326,7 @@ class Generator(nn.Module):
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
@@ -337,7 +337,7 @@ class VQAutoEncoder(nn.Module):
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
- self.attn_resolutions = attn_resolutions
+ self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,