aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorv0xie <28695009+v0xie@users.noreply.github.com>2023-10-18 04:27:44 -0700
committerv0xie <28695009+v0xie@users.noreply.github.com>2023-10-18 04:27:44 -0700
commit853e21d98eada4db9a9fd1ae8eda90cf763e2818 (patch)
treed507292825267486dfe4acccb51f604b9c80e30e /extensions-builtin
parent1c6efdbba774d603c592debaccd6f5ad827bd1b2 (diff)
faster by using cached R in forward
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network_oft.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index f085eca5..68efb1db 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule):
return R
def calc_updown(self, orig_weight):
+ # this works
R = self.R
+
+ # this causes major deepfrying i.e. just doesn't work
+ # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
+
if orig_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
else:
weight = torch.einsum("oi, op -> pi", orig_weight, R)
+
updown = orig_weight @ R
- output_shape = [orig_weight.size(0), R.size(1)]
- #output_shape = [R.size(0), orig_weight.size(1)]
+ output_shape = self.oft_blocks.shape
+
+ ## this works
+ # updown = orig_weight @ R
+ # output_shape = [orig_weight.size(0), R.size(1)]
+
return self.finalize_updown(updown, orig_weight, output_shape)
def forward(self, x, y=None):
x = self.org_forward(x)
if self.multiplier() == 0.0:
return x
- R = self.get_weight().to(x.device, dtype=x.dtype)
+ #R = self.get_weight().to(x.device, dtype=x.dtype)
+ R = self.R.to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)