aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/SwinIR/swinir_model_arch_v2.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-10 11:19:16 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-10 11:19:16 +0300
commit550256db1ce18778a9d56ff343d844c61b9f9b83 (patch)
treea17e8fd9cb475381c361844970ba2d9111938b6d /extensions-builtin/SwinIR/swinir_model_arch_v2.py
parent028d3f6425d85f122027c127fba8bcbf4f66ee75 (diff)
ruff manual fixes
Diffstat (limited to 'extensions-builtin/SwinIR/swinir_model_arch_v2.py')
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch_v2.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
index 0e28ae6e..d4c0b0da 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
@@ -74,9 +74,12 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=None):
super().__init__()
+
+ pretrained_window_size = pretrained_window_size or [0, 0]
+
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
@@ -698,13 +701,17 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ embed_dim=96, depths=None, num_heads=None,
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(Swin2SR, self).__init__()
+
+ depths = depths or [6, 6, 6, 6]
+ num_heads = num_heads or [6, 6, 6, 6]
+
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64