aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index cd920df5..5f71b6aa 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,6 +25,7 @@ from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
+ "linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
@@ -431,7 +432,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ steps_done = hypernetwork.step + 1
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1:
@@ -441,9 +444,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
- if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
@@ -452,8 +455,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
"learn_rate": scheduler.learn_rate
})
- if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()