about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: v0xie <28695009+v0xie@users.noreply.github.com> 2023-10-18 04:56:53 -0700
committer: v0xie <28695009+v0xie@users.noreply.github.com> 2023-10-18 04:56:53 -0700
commit: eb01d7f0e0fb46285985803296a25715165fb3f9 (patch)
tree: 6f4f43f0b99365cbc50ec318e00bf8aabf87c7d0
parent: 853e21d98eada4db9a9fd1ae8eda90cf763e2818 (diff)
faster by calculating R in updown and using cached R in forward
-rw-r--r-- extensions-builtin/Lora/network_oft.py 15
1 file changed, 8 insertions, 7 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 68efb1db..fd5b0c0f 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -58,17 +58,18 @@ class NetworkModuleOFT(network.NetworkModule):
def calc_updown(self, orig_weight):
# this works
- R = self.R
+ # R = self.R
+ self.R = self.get_weight(self.multiplier())
- # this causes major deepfrying i.e. just doesn't work
+ # sending R to device causes major deepfrying i.e. just doesn't work
# R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
- if orig_weight.dim() == 4:
- weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
- else:
- weight = torch.einsum("oi, op -> pi", orig_weight, R)
+ # if orig_weight.dim() == 4:
+ # weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
+ # else:
+ # weight = torch.einsum("oi, op -> pi", orig_weight, R)
- updown = orig_weight @ R
+ updown = orig_weight @ self.R
output_shape = self.oft_blocks.shape
## this works