aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authordiscus0434 <discus0434@gmail.com>2022-10-22 13:44:39 +0000
committerdiscus0434 <discus0434@gmail.com>2022-10-22 13:44:39 +0000
commit6a4fa73a38935a18779ce1809892730fd1572bee (patch)
tree4a2a26129fc5656d4a07c95d0ea205725233074b /modules
parent97749b7c7d9e0b27613aa79197f6094b4f6441d8 (diff)
small fix
Diffstat (limited to 'modules')
-rw-r--r--modules/hypernetworks/hypernetwork.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3372aae2..3bc71ee5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Add dropout
- if use_dropout:
- p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
- linears.append(torch.nn.Dropout(p=p))
+ # Add dropout expect last layer
+ if use_dropout and i < len(layer_structure) - 3:
+ linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)