aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAngelBottomless <35677394+aria1th@users.noreply.github.com>2022-10-23 21:07:07 +0900
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-24 09:07:39 +0300
commit40b56c9289bf9458ae5ef3c1990ccea851c6c3e2 (patch)
tree7021c99f5e0bf447c8e5b81ae4083678f68105aa
parentb297cc3324979ec78d69b2d11dd18030dfad7bcc (diff)
cleanup some code
-rw-r--r--modules/hypernetworks/hypernetwork.py14
1 files changed, 3 insertions, 11 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 33827210..4072bf54 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
+from collections import defaultdict, deque
from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
@@ -269,15 +270,6 @@ def stack_conds(conds):
return torch.stack(conds)
-def log_statistics(loss_info:dict, key, value):
- if key not in loss_info:
- loss_info[key] = [value]
- else:
- loss_info[key].append(value)
- if len(loss_info[key]) > 1024:
- loss_info[key].pop(0)
-
-
def statistics(data):
total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
recent_data = data[-32:]
@@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weight.requires_grad = True
size = len(ds.indexes)
- loss_dict = {}
+ loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
@@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
for entry in entries:
- log_statistics(loss_dict, entry.filename, loss.item())
+ loss_dict[entry.filename].append(loss.item())
optimizer.zero_grad()
weights[0].grad = None