aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora')
-rw-r--r--extensions-builtin/Lora/networks.py37
1 files changed, 29 insertions, 8 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index ba621139..1645b822 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -277,7 +277,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
self.weight.copy_(weights_backup)
if bias_backup is not None:
- self.bias.copy_(bias_backup)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias.copy_(bias_backup)
+ else:
+ self.bias.copy_(bias_backup)
+ else:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias = None
+ else:
+ self.bias = None
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -305,7 +313,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and getattr(self, 'bias', None) is not None:
- bias_backup = self.bias.to(devices.cpu, copy=True)
+ if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+ bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+ elif getattr(self, 'bias', None) is not None:
+ bias_backup = self.bias.to(devices.cpu, copy=True)
+ else:
+ bias_backup = None
self.network_bias_backup = bias_backup
if current_names != wanted_names:
@@ -322,8 +335,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown
- if ex_bias is not None and getattr(self, 'bias', None) is not None:
- self.bias += ex_bias
+ if ex_bias is not None and hasattr(self, 'bias'):
+ if self.bias is None:
+ self.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.bias += ex_bias
continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
@@ -333,14 +349,19 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
with torch.no_grad():
- updown_q = module_q.calc_updown(self.in_proj_weight)
- updown_k = module_k.calc_updown(self.in_proj_weight)
- updown_v = module_v.calc_updown(self.in_proj_weight)
+ updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+ updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+ updown_v, _ = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out = module_out.calc_updown(self.out_proj.weight)
+ updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out
+ if ex_bias is not None:
+ if self.out_proj.bias is None:
+ self.out_proj.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.out_proj.bias += ex_bias
continue
if module is None: