From de096d0ce752c96e45508dcc7b9e84f7dbe10cca Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Tue, 25 Oct 2022 14:48:49 +0900 Subject: Weight initialization and More activation func add weight init add weight init option in create_hypernetwork fstringify hypernet info save weight initialization info for further debugging fill bias with zero for He/Xavier initialize LayerNorm with Normal fix loading weight_init --- modules/hypernetworks/hypernetwork.py | 47 ++++++++++++++++++++++++++++------- modules/hypernetworks/ui.py | 4 ++- 2 files changed, 41 insertions(+), 10 deletions(-) (limited to 'modules/hypernetworks') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d647ea55..afbcdff8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -5,6 +5,7 @@ import html import os import sys import traceback +import inspect import modules.textual_inversion.dataset import torch @@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ from collections import defaultdict, deque from statistics import stdev, mean + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module): "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, } + activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module): else: for layer in self.linear: if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: - layer.weight.data.normal_(mean=0.0, std=0.01) - layer.bias.data.zero_() - + w, b = layer.weight.data, layer.bias.data + if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm: + normal_(w, mean=0.0, std=0.01) + normal_(b, mean=0.0, std=0.005) + elif weight_init == 'XavierUniform': + xavier_uniform_(w) + zeros_(b) + elif weight_init == 'XavierNormal': + xavier_normal_(w) + zeros_(b) + elif weight_init == 'KaimingUniform': + kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') + zeros_(b) + elif weight_init == 'KaimingNormal': + kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') + zeros_(b) + else: + raise KeyError(f"Key {weight_init} is not defined as initialization!") self.to(devices.device) def fix_old_state_dict(self, state_dict): @@ -105,7 +126,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -114,13 +135,14 @@ class Hypernetwork: self.sd_checkpoint_name = None self.layer_structure = layer_structure self.activation_func = activation_func + self.weight_init = weight_init self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -144,6 +166,7 @@ class Hypernetwork: state_dict['layer_structure'] = self.layer_structure state_dict['activation_func'] = self.activation_func state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['weight_initialization'] = self.weight_init state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -158,15 +181,21 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) + print(self.layer_structure) self.activation_func = state_dict.get('activation_func', None) + print(f"Activation function is {self.activation_func}") + self.weight_init = state_dict.get('weight_initialization', 'Normal') + print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) + print(f"Layer norm is set to {self.add_layer_norm}") self.use_dropout = state_dict.get('use_dropout', False) + print(f"Dropout usage is set to {self.use_dropout}" ) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), ) self.name = state_dict.get('name', self.name) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 2b472d87..2c6c0470 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork +keys = list(hypernetwork.HypernetworkModule.activation_dict.keys()) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) @@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, activation_func=activation_func, + weight_init=weight_init, add_layer_norm=add_layer_norm, use_dropout=use_dropout, ) -- cgit v1.2.1 From 7207e3bf49ed000464d288cd67e02f0ba8614dc3 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Tue, 25 Oct 2022 15:24:59 +0900 Subject: remove duplicate keys and lowercase --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/hypernetworks') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index afbcdff8..842b6447 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,7 +32,7 @@ class HypernetworkModule(torch.nn.Module): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, } - activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) + activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): super().__init__() -- cgit v1.2.1 From a524d137d0a89bb19a6676dc9b8fbb5d1b580678 Mon Sep 17 00:00:00 2001 From: timntorres Date: Mon, 24 Oct 2022 23:48:05 -0700 Subject: patch bug (SeverianVoid's comment on 5245c7a) --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/hypernetworks') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 842b6447..8113b35b 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -487,7 +487,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step -- cgit v1.2.1