author     AUTOMATIC <16777216c@gmail.com>    2022-10-21 10:13:24 +0300
committer  AUTOMATIC <16777216c@gmail.com>    2022-10-21 10:13:24 +0300
commit     03a1e288c4973dd2dff57a97469b40f146b6fccf
tree       0d35dc8de45f2d34038cff66c439bdc5e0ca4e95 /modules/hypernetworks
parent     e4877722e3e02b2da1ddacc0c7be25e6559c02f3
turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets
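What the message refers to: torch.nn.LayerNorm with the default elementwise_affine=True carries a learnable per-feature weight (scale) and bias (shift), just like torch.nn.Linear, so code that initializes and trains only Linear parameters silently skips them. A quick check (the feature size 320 here is illustrative, not taken from the patch):

    import torch

    # LayerNorm with elementwise_affine=True (the default) has learnable
    # parameters, just like Linear: a per-feature weight and bias.
    ln = torch.nn.LayerNorm(320)
    for name, p in ln.named_parameters():
        print(name, tuple(p.shape))
    # weight (320,)
    # bias (320,)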
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3274a802..b1a5d0c7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                if type(layer) == torch.nn.Linear:
+                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                     layer.weight.data.normal_(mean=0.0, std=0.01)
                     layer.bias.data.zero_()
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if type(layer) == torch.nn.Linear:
+            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure
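For reference, a minimal standalone sketch of the pattern this patch establishes: treat LayerNorm exactly like Linear both when initializing weights and when collecting trainables. The Sequential layout and layer sizes below are illustrative, not the project's actual HypernetworkModule:

    import torch

    layers = torch.nn.Sequential(
        torch.nn.Linear(320, 640),   # sizes are placeholders for this sketch
        torch.nn.LayerNorm(640),
        torch.nn.Linear(640, 320),
    )

    # Initialize both layer types, as the patched __init__ now does.
    for layer in layers:
        if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
            layer.weight.data.normal_(mean=0.0, std=0.01)
            layer.bias.data.zero_()

    # Collect weight and bias of both layer types, as the patched trainables() does.
    def trainables(seq):
        params = []
        for layer in seq:
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                params += [layer.weight, layer.bias]
        return params

    print(len(trainables(layers)))  # 6: weight and bias for each of the three layers

Without the second branch of the condition, the LayerNorm parameters would be left out of this list and the optimizer would never update them, which is exactly the bug the commit fixes.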