aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--extensions-builtin/Lora/networks.py142
1 files changed, 111 insertions, 31 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index bc722e90..22fdff4a 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -1,3 +1,4 @@
+import logging
import os
import re
@@ -7,6 +8,7 @@ import network_hada
import network_ia3
import network_lokr
import network_full
+import network_norm
import torch
from typing import Union
@@ -19,6 +21,7 @@ module_types = [
network_ia3.ModuleTypeIa3(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
+ network_norm.ModuleTypeNorm(),
]
@@ -31,6 +34,8 @@ suffix_conversion = {
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
+ "norm1": "in_layers_0",
+ "norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
@@ -190,7 +195,7 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module
if keys_failed_to_match:
- print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+ logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
return net
@@ -203,7 +208,6 @@ def purge_networks_from_memory():
devices.torch_gc()
-
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -244,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if net is None:
failed_to_load_networks.append(name)
- print(f"Couldn't find network with name {name}")
+ logging.info(f"Couldn't find network with name {name}")
continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -253,25 +257,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.append(net)
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+ sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
+ bias_backup = getattr(self, "network_bias_backup", None)
- if weights_backup is None:
+ if weights_backup is None and bias_backup is None:
return
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
+ if weights_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+ if bias_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias.copy_(bias_backup)
+ else:
+ self.bias.copy_(bias_backup)
else:
- self.weight.copy_(weights_backup)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias = None
+ else:
+ self.bias = None
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
@@ -294,21 +311,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if bias_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+ bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+ elif getattr(self, 'bias', None) is not None:
+ bias_backup = self.bias.to(devices.cpu, copy=True)
+ else:
+ bias_backup = None
+ self.network_bias_backup = bias_backup
+
if current_names != wanted_names:
network_restore_weights_from_backup(self)
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
- with torch.no_grad():
- updown = module.calc_updown(self.weight)
-
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
- # inpainting model. zero pad updown to make channel[1] 4 to 9
- updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+ try:
+ with torch.no_grad():
+ updown, ex_bias = module.calc_updown(self.weight)
+
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+
+ self.weight += updown
+ if ex_bias is not None and hasattr(self, 'bias'):
+ if self.bias is None:
+ self.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.bias += ex_bias
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
- self.weight += updown
- continue
+ continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)
@@ -316,21 +353,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
- with torch.no_grad():
- updown_q = module_q.calc_updown(self.in_proj_weight)
- updown_k = module_k.calc_updown(self.in_proj_weight)
- updown_v = module_v.calc_updown(self.in_proj_weight)
- updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out = module_out.calc_updown(self.out_proj.weight)
-
- self.in_proj_weight += updown_qkv
- self.out_proj.weight += updown_out
- continue
+ try:
+ with torch.no_grad():
+ updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+ updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+ updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+ updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += updown_out
+ if ex_bias is not None:
+ if self.out_proj.bias is None:
+ self.out_proj.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.out_proj.bias += ex_bias
+
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+ continue
if module is None:
continue
- print(f'failed to calculate network weights for layer {network_layer_name}')
+ logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names
@@ -357,7 +406,7 @@ def network_forward(module, input, original_forward):
if module is None:
continue
- y = module.forward(y, input)
+ y = module.forward(input, y)
return y
@@ -397,6 +446,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+def network_GroupNorm_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
@@ -473,6 +552,7 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
+extra_network_lora = None
available_networks = {}
available_network_aliases = {}