Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--    modules/hypernetworks/hypernetwork.py    10
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3a44b377..905cbeef 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -35,13 +35,13 @@ class HypernetworkModule(torch.nn.Module):
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
             # if skip_first_layer because first parameters potentially contain negative values
-            if i < 1: continue
+            # if i < 1: continue
             if activation_func in HypernetworkModule.activation_dict:
                 linears.append(HypernetworkModule.activation_dict[activation_func]())
             else:
                 print("Invalid key {} encountered as activation function!".format(activation_func))
             # if use_dropout:
-            linears.append(torch.nn.Dropout(p=0.3))
+            # linears.append(torch.nn.Dropout(p=0.3))
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
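For context, a minimal standalone sketch of what the layer-construction loop builds after this hunk; dim, the layer_structure tuple, and the activation table below are illustrative stand-ins, not values taken from the repository:

import torch

# Illustrative values; in the real module these come from the hypernetwork's settings.
dim = 320
layer_structure = (1, 2, 1)
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU}
activation_func = "relu"
add_layer_norm = True

linears = []
for i in range(len(layer_structure) - 1):
    linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
    # The activation is now appended for every layer (the removed `if i < 1: continue`
    # used to skip it for the first layer), and the Dropout(p=0.3) insertion is disabled.
    if activation_func in activation_dict:
        linears.append(activation_dict[activation_func]())
    if add_layer_norm:
        linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))

model = torch.nn.Sequential(*linears)
print(model)  # Linear -> ReLU -> LayerNorm, repeated per step of layer_structure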
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if not "ReLU" in layer.__str__():
+            if isinstance(layer, torch.nn.Linear):
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure
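A short illustration of why the isinstance() test is the safer filter here: the old substring check on the layer's repr also admitted LayerNorm (and anything else whose name lacks "ReLU"), while isinstance() selects exactly the Linear layers. The layer sizes below are illustrative:

import torch

# Illustrative layer stack similar to what HypernetworkModule.linear holds.
linear = torch.nn.Sequential(
    torch.nn.Linear(320, 640),
    torch.nn.LeakyReLU(),
    torch.nn.LayerNorm(640),
    torch.nn.Linear(640, 320),
)

trainables = []
for layer in linear:
    # isinstance() keeps only the Linear weights and biases; the old repr-based check
    # would also have pulled in LayerNorm's parameters, and would fail with an
    # AttributeError on activations that have no .weight and whose name lacks "ReLU".
    if isinstance(layer, torch.nn.Linear):
        trainables += [layer.weight, layer.bias]

print(len(trainables))  # 4: weight and bias for each of the two Linear layers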
@@ -304,8 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename

     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    # if optimizer == "Adam": or else Adam / AdamW / etc...
-    optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, entries in pbar:
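Finally, a minimal sketch of the optimizer setup as it stands after this commit; `weights` and the learning rate below are illustrative placeholders (in the training code `weights` comes from the hypernetwork's modules and the rate from LearnRateScheduler):

import torch

# Illustrative parameter list standing in for the hypernetwork's trainable weights.
weights = [torch.nn.Parameter(torch.randn(320, 320))]
learn_rate = 5e-6  # placeholder for scheduler.learn_rate

# AdamW differs from Adam by decoupling weight decay from the gradient update;
# the comment in the diff notes that the optimizer (Adam / AdamW / SGD, etc.)
# could be selected by name instead of being hard-coded.
optimizer = torch.optim.AdamW(weights, lr=learn_rate)

# One illustrative step to show the optimizer is wired up correctly.
loss = (weights[0] ** 2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()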