aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAngelBottomless <35677394+aria1th@users.noreply.github.com>2022-10-23 21:29:53 +0900
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-24 09:07:39 +0300
commit348f89c8d40397c1875cff4a7331018785f9c3b8 (patch)
tree90ef350a5b3512b11af9588a8a7b77c9b69f4f83 /modules
parent40b56c9289bf9458ae5ef3c1990ccea851c6c3e2 (diff)
statistics for pbar
Diffstat (limited to 'modules')
-rw-r--r--modules/hypernetworks/hypernetwork.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 4072bf54..48b56029 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
+ previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
@@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for i, entries in pbar:
hypernetwork.step = i + ititial_step
if len(loss_dict) > 0:
- previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
@@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
- pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
+ else:
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.