aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
authorMelan <alexleander91@gmail.com>2022-10-13 12:37:58 +0200
committerMelan <alexleander91@gmail.com>2022-10-13 12:37:58 +0200
commit8636b50aea83f9c743f005722d9f3f8ee9303e00 (patch)
tree01afcd9e7cecdc6bac22af40da01c4ef2d49aa82 /modules/hypernetworks/hypernetwork.py
parent1cfc2a18981ee56bdb69a2de7b463a11ad05e329 (diff)
Add learn_rate to csv and removed a left-over debug statement
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6522078f..2751a8c8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -257,19 +257,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
- print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}")
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
- csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"])
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
if write_csv_header:
csv_writer.writeheader()
csv_writer.writerow({"step": hypernetwork.step,
- "loss": f"{losses.mean():.7f}"})
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')