aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--modules/hypernetworks/hypernetwork.py154
-rw-r--r--modules/hypernetworks/ui.py15
2 files changed, 130 insertions, 39 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74300122..98a7b62e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,40 +1,61 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
-
-import torch
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
+import modules.textual_inversion.dataset
import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
-
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+ activation_dict = {
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ }
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__()
- assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+ # Add an activation func
+ if activation_func == "linear" or activation_func is None:
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+ # Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+ # Add dropout expect last layer
+ if use_dropout and i < len(layer_structure) - 3:
+ linears.append(torch.nn.Dropout(p=0.3))
+
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
@@ -42,8 +63,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer.weight.data.normal_(mean=0.0, std=0.01)
+ layer.bias.data.zero_()
self.to(devices.device)
@@ -69,7 +91,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- layer_structure += [layer.weight, layer.bias]
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -81,7 +104,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -89,12 +112,14 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
+ self.activation_func = activation_func
self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -116,7 +141,9 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
+ state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -130,13 +157,15 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.activation_func = state_dict.get('activation_func', None)
self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -240,7 +269,39 @@ def stack_conds(conds):
return torch.stack(conds)
+def log_statistics(loss_info:dict, key, value):
+ if key not in loss_info:
+ loss_info[key] = [value]
+ else:
+ loss_info[key].append(value)
+ if len(loss_info) > 1024:
+ loss_info.pop(0)
+
+
+def statistics(data):
+ total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
+ recent_data = data[-32:]
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(loss_info[key])
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
+
+
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+ from modules import images
+
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -279,22 +340,32 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for weight in weights:
weight.requires_grad = True
- losses = torch.zeros((32,))
+ size = len(ds.indexes)
+ loss_dict = {}
+ losses = torch.zeros((size,))
+ previous_mean_loss = 0
+ print("Mean loss of {} elements".format(size))
last_saved_file = "<none>"
last_saved_image = "<none>"
+ forced_filename = "<none>"
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ steps_without_grad = 0
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
-
+ if len(loss_dict) > 0:
+ previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
+
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
@@ -311,26 +382,39 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
-
+ for entry in entries:
+ log_statistics(loss_dict, entry.filename, loss.item())
+
optimizer.zero_grad()
+ weights[0].grad = None
loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
optimizer.step()
- mean_loss = losses.mean()
- if torch.isnan(mean_loss):
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
- pbar.set_description(f"loss: {mean_loss:.7f}")
+ pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{mean_loss:.7f}",
+ "loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+ forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -366,27 +450,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- image.save(last_saved_image)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p>
-Loss: {mean_loss:.7f}<br/>
+Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
-
+
+ report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
+ hypernetwork.name = hypernetwork_name
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename
-
-
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index e0741d08..2b472d87 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,16 +3,19 @@ import os
import re
import gradio as gr
-
-import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
@@ -21,7 +24,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
+ activation_func=activation_func,
add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
)
hypernet.save(fn)