Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
 modules/hypernetworks/hypernetwork.py | 79 +++++++++++++++++++++++++--------
 1 file changed, 63 insertions(+), 16 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b8695fc1..7d519cd9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -22,45 +22,86 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
 
-    def __init__(self, dim, state_dict=None):
+    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
         super().__init__()
+
+        if layer_structure is not None:
+            assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+            assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+        else:
+            layer_structure = parse_layer_structure(dim, state_dict)
+
+        linears = []
+        for i in range(len(layer_structure) - 1):
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+            if add_layer_norm:
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
-        self.linear1 = torch.nn.Linear(dim, dim * 2)
-        self.linear2 = torch.nn.Linear(dim * 2, dim)
+        self.linear = torch.nn.Sequential(*linears)
 
         if state_dict is not None:
-            self.load_state_dict(state_dict, strict=True)
+            try:
+                self.load_state_dict(state_dict)
+            except RuntimeError:
+                self.try_load_previous(state_dict)
         else:
-
-            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
-            self.linear1.bias.data.zero_()
-            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
-            self.linear2.bias.data.zero_()
+            for layer in self.linear:
+                # LayerNorm modules keep their default (ones/zeros) init;
+                # only Linear layers get small random weights.
+                if type(layer) == torch.nn.Linear:
+                    layer.weight.data.normal_(mean=0.0, std=0.01)
+                    layer.bias.data.zero_()
 
         self.to(devices.device)
 
+    def try_load_previous(self, state_dict):
+        states = self.state_dict()
+        states['linear.0.bias'].copy_(state_dict['linear1.bias'])
+        states['linear.0.weight'].copy_(state_dict['linear1.weight'])
+        states['linear.1.bias'].copy_(state_dict['linear2.bias'])
+        states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+
     def forward(self, x):
-        return x + (self.linear2(self.linear1(x))) * self.multiplier
+        return x + self.linear(x) * self.multiplier
+
+    def trainables(self):
+        params = []
+        for layer in self.linear:
+            params += [layer.weight, layer.bias]
+        return params
 
 
 def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
 
 
+def parse_layer_structure(dim, state_dict):
+    i = 0
+    layer_structure = [1]
+
+    while (key := "linear.{}.weight".format(i)) in state_dict:
+        weight = state_dict[key]
+        layer_structure.append(len(weight) // dim)
+        i += 1
+
+    # Checkpoints from before this change store 'linear1'/'linear2' keys and
+    # always used the fixed dim -> dim*2 -> dim shape, i.e. a [1, 2, 1] structure.
+    if len(layer_structure) == 1 and "linear1.weight" in state_dict:
+        return [1, 2, 1]
+
+    return layer_structure
+
+
 class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
         self.filename = None
         self.name = name
         self.layers = {}
         self.step = 0
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
+        self.layer_structure = layer_structure
+        self.add_layer_norm = add_layer_norm
 
         for size in enable_sizes or []:
-            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+            self.layers[size] = (
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+            )
 
     def weights(self):
         res = []
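
To make the new constructor concrete, here is a minimal sketch of the shapes it builds (it assumes the webui module context for the devices import; the numbers are illustrative):

    # layer_structure entries are width multipliers of dim, so [1, 2, 1]
    # reproduces the old fixed architecture:
    m = HypernetworkModule(320, layer_structure=[1, 2, 1])
    print(m.linear)
    # Sequential(
    #   (0): Linear(in_features=320, out_features=640, bias=True)
    #   (1): Linear(in_features=640, out_features=320, bias=True)
    # )
    # forward() applies it residually: x + m.linear(x) * multiplier.
    # With add_layer_norm=True, a LayerNorm follows each Linear in the
    # same Sequential.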
@@ -68,7 +109,7 @@ class Hypernetwork:
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.train()
-                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+                res += layer.trainables()
 
         return res
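
weights() flips every module into train mode and returns one flat parameter list. A hedged sketch of how a training loop might consume it (the optimizer choice and learning rate are illustrative, not fixed by this diff):

    hn = Hypernetwork(name="example", enable_sizes=[320, 640, 768, 1280],
                      layer_structure=[1, 2, 1], add_layer_norm=False)
    optimizer = torch.optim.AdamW(hn.weights(), lr=5e-6)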
@@ -80,6 +121,8 @@ class Hypernetwork:
         state_dict['step'] = self.step
         state_dict['name'] = self.name
+        state_dict['layer_structure'] = self.layer_structure
+        state_dict['is_layer_norm'] = self.add_layer_norm
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
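
Together with the per-size entries that load() iterates below, a saved checkpoint is therefore a plain torch-serialized dict along these lines (values illustrative):

    {
        320: (module_a_state_dict, module_b_state_dict),  # one pair per size
        'step': 10000,
        'name': 'example',
        'layer_structure': [1, 2, 1],
        'is_layer_norm': False,
        'sd_checkpoint': '...',
        'sd_checkpoint_name': '...',
    }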
@@ -94,10 +137,15 @@ class Hypernetwork:
         for size, sd in state_dict.items():
             if type(size) == int:
-                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+                self.layers[size] = (
+                    HypernetworkModule(size, sd[0], state_dict.get('layer_structure', None), state_dict.get('is_layer_norm', False)),
+                    HypernetworkModule(size, sd[1], state_dict.get('layer_structure', None), state_dict.get('is_layer_norm', False)),
+                )
 
         self.name = state_dict.get('name', self.name)
         self.step = state_dict.get('step', 0)
+        self.layer_structure = state_dict.get('layer_structure', None)
+        self.add_layer_norm = state_dict.get('is_layer_norm', False)
         self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
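
The .get() fallbacks matter for files written before this change: old checkpoints carry neither 'layer_structure' nor 'is_layer_norm', so each module infers its shape via parse_layer_structure(), and old-format module state dicts ('linear1'/'linear2' keys) are migrated by try_load_previous(). A sketch of the load path (the signature is assumed from the surrounding file; the filename is hypothetical):

    hn = Hypernetwork()
    hn.load('models/hypernetworks/legacy.pt')  # hypothetical path
    # legacy two-layer modules reappear as linear.0 / linear.1 in the Sequential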
@@ -226,7 +274,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
-
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
@@ -261,7 +308,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         with torch.autocast("cuda"):
             c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-#            c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
             x = torch.stack([entry.latent for entry in entries]).to(devices.device)
             loss = shared.sd_model(x, c)[0]
             del x
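
The commented-out torch.vstack call hints at why stack_conds exists: cond tensors from different prompts can have different token counts, so they cannot be stacked directly. A hedged sketch of the idea (not necessarily this file's exact implementation):

    import torch

    def stack_conds_sketch(conds):
        # pad each cond to the longest token count by repeating its last row,
        # then stack everything into one [batch, tokens, dim] tensor
        token_count = max(c.shape[0] for c in conds)
        padded = [torch.vstack([c, c[-1:].repeat(token_count - c.shape[0], 1)])
                  for c in conds]
        return torch.stack(padded)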