Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 12
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b7a04038..3132a56c 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = []
for i in range(len(layer_structure) - 1):
@@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
+ # If ReLU, Skip adding it to the first layer to avoid dying ReLU
+ elif activation_func == "relu" and i < 1:
+ pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
- raise RuntimeError(
- "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
- )
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add dropout
if use_dropout:
@@ -166,8 +166,8 @@ class Hypernetwork:
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
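
For context, here is a minimal standalone sketch of the layer-building logic as it stands after this commit, assuming activation_dict maps names to torch.nn module classes and that dropout uses a fixed probability; the helper name build_linears and these exact values are illustrative, not taken from the repository.

    import torch

    # Illustrative subset of what HypernetworkModule's activation_dict could
    # look like (an assumption based on the names referenced in this diff).
    activation_dict = {
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
    }

    def build_linears(dim, layer_structure, activation_func, use_dropout):
        linears = []
        for i in range(len(layer_structure) - 1):
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]),
                                           int(dim * layer_structure[i + 1])))
            if activation_func == "linear" or activation_func is None:
                pass
            elif activation_func == "relu" and i < 1:
                # Skip ReLU on the first layer to avoid dying ReLU,
                # mirroring the change in this commit.
                pass
            elif activation_func in activation_dict:
                linears.append(activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
            if use_dropout:
                # Dropout probability is illustrative here.
                linears.append(torch.nn.Dropout(p=0.3))
        return torch.nn.Sequential(*linears)

Call sites such as Hypernetwork.load must now pass the constructor arguments in the order (layer_structure, activation_func, add_layer_norm, use_dropout), which is what the second hunk fixes.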