aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordiscus0434 <discus0434@gmail.com>2022-10-22 13:00:44 +0000
committerdiscus0434 <discus0434@gmail.com>2022-10-22 13:00:44 +0000
commit7912acef725832debef58c4c7bf8ec22fb446c0b (patch)
treed85d1455ea4f9a888cca76c5167260a6363c0a4f
parentfccba4729db341a299db3343e3264fecd9459a07 (diff)
small fix
-rw-r--r--modules/hypernetworks/hypernetwork.py12
-rw-r--r--modules/ui.py1
2 files changed, 5 insertions, 8 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3132a56c..7d12e0ff 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
- # If ReLU, Skip adding it to the first layer to avoid dying ReLU
- elif activation_func == "relu" and i < 1:
- pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
- # Add dropout
- if use_dropout:
- linears.append(torch.nn.Dropout(p=0.3))
-
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+ # Add dropout
+ if use_dropout:
+ p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
+ linears.append(torch.nn.Dropout(p=p))
+
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
diff --git a/modules/ui.py b/modules/ui.py
index cd118552..eca887ca 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1244,7 +1244,6 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):