From 2ce52d32e41fb523d1494f45073fd18496e52d35 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 16:31:12 +0000 Subject: fix for #3086 failing to load any previous hypernet --- modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++------------------- 1 file changed, 28 insertions(+), 32 deletions(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..74300122 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() - if layer_structure is not None: - assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - else: - layer_structure = parse_layer_structure(dim, state_dict) + + assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): @@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module): self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - try: - self.load_state_dict(state_dict) - except RuntimeError: - self.try_load_previous(state_dict) + self.fix_old_state_dict(state_dict) + self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() self.to(devices.device) - def try_load_previous(self, state_dict): - states = self.state_dict() - states['linear.0.bias'].copy_(state_dict['linear1.bias']) - states['linear.0.weight'].copy_(state_dict['linear1.weight']) - states['linear.1.bias'].copy_(state_dict['linear2.bias']) - states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def fix_old_state_dict(self, state_dict): + changes = { + 'linear1.bias': 'linear.0.bias', + 'linear1.weight': 'linear.0.weight', + 'linear2.bias': 'linear.1.bias', + 'linear2.weight': 'linear.1.weight', + } + + for fr, to in changes.items(): + x = state_dict.get(fr, None) + if x is None: + continue + + del state_dict[fr] + state_dict[to] = x def forward(self, x): return x + self.linear(x) * self.multiplier @@ -71,18 +77,6 @@ def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength -def parse_layer_structure(dim, state_dict): - i = 0 - layer_structure = [1] - - while (key := "linear.{}.weight".format(i)) in state_dict: - weight = state_dict[key] - layer_structure.append(len(weight) // dim) - i += 1 - - return layer_structure - - class Hypernetwork: filename = None name = None @@ -135,17 +129,18 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') + self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), - HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) - self.layer_structure = state_dict.get('layer_structure', None) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) @@ -244,6 +239,7 @@ def stack_conds(conds): return torch.stack(conds) + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' -- cgit v1.2.1 From 6f98e89486f55b0e4657e96ce640cf1c4675d187 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 00:10:45 +0000 Subject: update --- modules/hypernetworks/hypernetwork.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) (limited to 'modules/hypernetworks/hypernetwork.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74300122..7d617680 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): super().__init__() - assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if activation_func == "relu": + linears.append(torch.nn.ReLU()) + if activation_func == "leakyrelu": + linears.append(torch.nn.LeakyReLU()) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean=0.0, std=0.01) - layer.bias.data.zero_() + if not "ReLU" in layer.__str__(): + layer.weight.data.normal_(mean=0.0, std=0.01) + layer.bias.data.zero_() self.to(devices.device) @@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - layer_structure += [layer.weight, layer.bias] + if not "ReLU" in layer.__str__(): + layer_structure += [layer.weight, layer.bias] return layer_structure @@ -81,7 +87,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): self.filename = None self.name = name self.layers = {} @@ -90,11 +96,12 @@ class Hypernetwork: self.sd_checkpoint_name = None self.layer_structure = layer_structure self.add_layer_norm = add_layer_norm + self.activation_func = activation_func for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), ) def weights(self): @@ -117,6 +124,7 @@ class Hypernetwork: state_dict['name'] = self.name state_dict['layer_structure'] = self.layer_structure state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['activation_func'] = self.activation_func state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -131,12 +139,13 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.add_layer_norm = state_dict.get('is_layer_norm', False) + self.activation_func = state_dict.get('activation_func', None) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), - HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), ) self.name = state_dict.get('name', self.name) -- cgit v1.2.1