aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora
diff options
context:
space:
mode:
authorv0xie <28695009+v0xie@users.noreply.github.com>2023-11-15 03:08:50 -0800
committerv0xie <28695009+v0xie@users.noreply.github.com>2023-11-15 03:08:50 -0800
commitd6d0b22e6657fc84039e82ee735a57101bfe7c17 (patch)
tree9a77f9e0266b5ef5ed2ebd13837d0b3929fd169a /extensions-builtin/Lora
parent7edd50f304ebf8a713839035d4e9eacaa98d3762 (diff)
fix: ignore calc_scale() for COFT which has very small alpha
Diffstat (limited to 'extensions-builtin/Lora')
-rw-r--r--extensions-builtin/Lora/network_oft.py16
1 files changed, 5 insertions, 11 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 93402bb2..c45a8d23 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
if not is_other_linear:
- #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
- # orig_weight=orig_weight.permute(1, 0)
-
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
- # without this line the results are significantly worse / less accurate
+ # ensure skew-symmetric matrix
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
- #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
- # orig_weight=orig_weight.permute(1, 0)
-
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
else:
@@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown(self, orig_weight):
- multiplier = self.multiplier() * self.calc_scale()
- #if self.is_kohya:
- # return self.calc_updown_kohya(orig_weight, multiplier)
- #else:
+ # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it
+ #multiplier = self.multiplier() * self.calc_scale()
+ multiplier = self.multiplier()
+
return self.calc_updown_kb(orig_weight, multiplier)
# override to remove the multiplier/scale factor; it's already multiplied in get_weight