aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/SwinIR/swinir_model_arch.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/SwinIR/swinir_model_arch.py')
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
index 863f42db..75f7bedc 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch.py
@@ -644,13 +644,17 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ embed_dim=96, depths=None, num_heads=None,
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
+
+ depths = depths or [6, 6, 6, 6]
+ num_heads = num_heads or [6, 6, 6, 6]
+
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64