author     Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>   2023-08-13 02:27:39 +0800
committer  Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>   2023-08-13 02:27:39 +0800
commit     bd4da4474bef5c9c1f690c62b971704ee73d2860 (patch)
tree       f3624d0521366fb23c4e7861ea1d0a04b43483e6 /extensions-builtin/Lora/networks.py
parent     b2080756fcdc328292fc38998c06ccf23e53bd7e (diff)
Add extra norm module into built-in lora ext
Referring to LyCORIS 1.9.0.dev6, this adds a new option and module for training norm layers (which is reported to be good for style).
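The new network_norm module registered in the diff below lives in a separate file (presumably extensions-builtin/Lora/network_norm.py) that is not part of this diff. As a rough, hypothetical sketch of what such a module type could look like, assuming made-up "w_norm"/"b_norm" weight keys and relying only on the (updown, ex_bias) return contract that the patched network_apply_weights below expects:

# Hypothetical sketch, not the actual network_norm.py from this commit.
# A norm module simply carries additive deltas for a GroupNorm/LayerNorm
# layer's weight and bias, and hands them back as the (updown, ex_bias)
# pair consumed by the patched network_apply_weights() in the diff below.
import torch


class ModuleTypeNorm:
    def create_module(self, net, weights):
        # Claim the weights only if they look like norm deltas (key names assumed).
        if all(k in weights for k in ("w_norm", "b_norm")):
            return NetworkModuleNorm(net, weights)
        return None


class NetworkModuleNorm:
    def __init__(self, net, weights):
        self.net = net
        self.w_norm = weights.get("w_norm")  # delta for the norm layer's weight
        self.b_norm = weights.get("b_norm")  # delta for the norm layer's bias (may be absent)

    def calc_updown(self, orig_weight: torch.Tensor):
        # Match device/dtype of the live layer so "self.weight += updown" works.
        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
        ex_bias = None
        if self.b_norm is not None:
            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
        return updown, ex_bias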
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--   extensions-builtin/Lora/networks.py   64
1 file changed, 55 insertions, 9 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 7e3415ac..74cefe43 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -7,6 +7,7 @@ import network_hada
 import network_ia3
 import network_lokr
 import network_full
+import network_norm
 import torch
 from typing import Union
@@ -19,6 +20,7 @@ module_types = [
     network_ia3.ModuleTypeIa3(),
     network_lokr.ModuleTypeLokr(),
     network_full.ModuleTypeFull(),
+    network_norm.ModuleTypeNorm(),
 ]
@@ -31,6 +33,8 @@ suffix_conversion = {
     "resnets": {
         "conv1": "in_layers_2",
         "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
         "time_emb_proj": "emb_layers_1",
         "conv_shortcut": "skip_connection",
     }
@@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
+    bias_backup = getattr(self, "network_bias_backup", None)
-    if weights_backup is None:
+    if weights_backup is None and bias_backup is None:
         return
-    if isinstance(self, torch.nn.MultiheadAttention):
-        self.in_proj_weight.copy_(weights_backup[0])
-        self.out_proj.weight.copy_(weights_backup[1])
-    else:
-        self.weight.copy_(weights_backup)
+    if weights_backup is not None:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.in_proj_weight.copy_(weights_backup[0])
+            self.out_proj.weight.copy_(weights_backup[1])
+        else:
+            self.weight.copy_(weights_backup)
+    if bias_backup is not None:
+        self.bias.copy_(bias_backup)
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
@@ -294,6 +303,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         self.network_weights_backup = weights_backup
+    bias_backup = getattr(self, "network_bias_backup", None)
+    if bias_backup is None and getattr(self, 'bias', None) is not None:
+        bias_backup = self.bias.to(devices.cpu, copy=True)
+        self.network_bias_backup = bias_backup
+
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)
@@ -301,13 +315,15 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             module = net.modules.get(network_layer_name, None)
             if module is not None and hasattr(self, 'weight'):
                 with torch.no_grad():
-                    updown = module.calc_updown(self.weight)
+                    updown, ex_bias = module.calc_updown(self.weight)
                     if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                         # inpainting model. zero pad updown to make channel[1] 4 to 9
                         updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
                     self.weight += updown
+                    if getattr(self, 'bias', None) is not None:
+                        self.bias += ex_bias
                 continue
             module_q = net.modules.get(network_layer_name + "_q_proj", None)
@@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
     return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+def network_GroupNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)
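The GroupNorm/LayerNorm forward and load_state_dict replacements above only define the hook functions; this diff does not show where torch.nn actually gets patched to call them. A hedged sketch of how that wiring would look, mirroring the *_before_network naming the hooks already reference (the hookup location and the patched attribute names are assumptions, not shown in this commit):

# Hedged sketch of the monkey-patch wiring implied by the hook names above;
# the real hookup happens outside networks.py and is not part of this diff.
import torch
import networks  # the module this diff modifies

# Stash the original methods once under the *_before_network names the hooks call.
if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict

# Route every GroupNorm forward/state-dict load through the network-aware replacements.
torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict

# LayerNorm would be patched the same way with its LayerNorm_* counterparts.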