From b297cc3324979ec78d69b2d11dd18030dfad7bcc Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 20:06:42 +0900 Subject: Hypernetworks - fix KeyError in statistics caching Statistics logging has changed to {filename : list[losses]}, so it has to use loss_info[key].pop() --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 98a7b62e..33827210 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -274,8 +274,8 @@ def log_statistics(loss_info:dict, key, value): loss_info[key] = [value] else: loss_info[key].append(value) - if len(loss_info) > 1024: - loss_info.pop(0) + if len(loss_info[key]) > 1024: + loss_info[key].pop(0) def statistics(data): -- cgit v1.2.1 From 40b56c9289bf9458ae5ef3c1990ccea851c6c3e2 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:07:07 +0900 Subject: cleanup some code --- modules/hypernetworks/hypernetwork.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 33827210..4072bf54 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from collections import defaultdict, deque from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): @@ -269,15 +270,6 @@ def stack_conds(conds): return torch.stack(conds) -def log_statistics(loss_info:dict, key, value): - if key not in loss_info: - loss_info[key] = [value] - else: - loss_info[key].append(value) - if len(loss_info[key]) > 1024: - loss_info[key].pop(0) - - def statistics(data): total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" recent_data = data[-32:] @@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log weight.requires_grad = True size = len(ds.indexes) - loss_dict = {} + loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) @@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() for entry in entries: - log_statistics(loss_dict, entry.filename, loss.item()) + loss_dict[entry.filename].append(loss.item()) optimizer.zero_grad() weights[0].grad = None -- cgit v1.2.1 From 348f89c8d40397c1875cff4a7331018785f9c3b8 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:29:53 +0900 Subject: statistics for pbar --- modules/hypernetworks/hypernetwork.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4072bf54..48b56029 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log size = len(ds.indexes) loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) + previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) @@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log for i, entries in pbar: hypernetwork.step = i + ititial_step if len(loss_dict) > 0: - previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: @@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) + else: + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. -- cgit v1.2.1 From 0d2e1dac407a0e2f5b148d314715f0457b2525b7 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:41:39 +0900 Subject: convert deque -> list I don't feel this being efficient --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 48b56029..fb510fa7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -282,7 +282,7 @@ def report_statistics(loss_info:dict): for key in keys: try: print("Loss statistics for file " + key) - info, recent = statistics(loss_info[key]) + info, recent = statistics(list(loss_info[key])) print(info) print(recent) except Exception as e: -- cgit v1.2.1 From e9a410b5357612f63528015c5533c2185dcff92e Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:47:39 +0900 Subject: check length for variance --- modules/hypernetworks/hypernetwork.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fb510fa7..d647ea55 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -271,9 +271,17 @@ def stack_conds(conds): def statistics(data): - total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" + if len(data) < 2: + std = 0 + else: + std = stdev(data) + total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})" recent_data = data[-32:] - recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})" + if len(recent_data) < 2: + std = 0 + else: + std = stdev(recent_data) + recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})" return total_information, recent_information -- cgit v1.2.1