aboutsummaryrefslogtreecommitdiff
path: root/modules/codeformer/codeformer_arch.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-10 11:19:16 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-10 11:19:16 +0300
commit550256db1ce18778a9d56ff343d844c61b9f9b83 (patch)
treea17e8fd9cb475381c361844970ba2d9111938b6d /modules/codeformer/codeformer_arch.py
parent028d3f6425d85f122027c127fba8bcbf4f66ee75 (diff)
ruff manual fixes
Diffstat (limited to 'modules/codeformer/codeformer_arch.py')
-rw-r--r--modules/codeformer/codeformer_arch.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
index 00c407de..ff1c0b4b 100644
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module):
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
- connect_list=['32', '64', '128', '256'],
- fix_modules=['quantize','generator']):
+ connect_list=None,
+ fix_modules=None):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+ connect_list = connect_list or ['32', '64', '128', '256']
+ fix_modules = fix_modules or ['quantize', 'generator']
+
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():