about summary refs log tree commit diff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network_oft.py15
1 file changed, 8 insertions, 7 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 68efb1db..fd5b0c0f 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -58,17 +58,18 @@ class NetworkModuleOFT(network.NetworkModule):
def calc_updown(self, orig_weight):
# this works
- R = self.R
+ # R = self.R
+ self.R = self.get_weight(self.multiplier())
- # this causes major deepfrying i.e. just doesn't work
+ # sending R to device causes major deepfrying i.e. just doesn't work
# R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
- if orig_weight.dim() == 4:
- weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
- else:
- weight = torch.einsum("oi, op -> pi", orig_weight, R)
+ # if orig_weight.dim() == 4:
+ # weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
+ # else:
+ # weight = torch.einsum("oi, op -> pi", orig_weight, R)
- updown = orig_weight @ R
+ updown = orig_weight @ self.R
output_shape = self.oft_blocks.shape
## this works