aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorv0xie <28695009+v0xie@users.noreply.github.com>2023-10-21 13:43:31 -0700
committerv0xie <28695009+v0xie@users.noreply.github.com>2023-10-21 13:43:31 -0700
commit2d8c894b274d60a3e3563a2ace23c4ebcea9e652 (patch)
treed812648a1cf56624fe9fcb6b6e6f2975f79aca28 /extensions-builtin
parent0550659ce6e1c37d1ab05cb8a2cb31d499fa552f (diff)
refactor: use forward hook instead of custom forward
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network_oft.py33
1 files changed, 24 insertions, 9 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 4e8382c1..8e561ab0 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule):
# how do we revert this to unload the weights?
def apply_to(self):
self.org_forward = self.org_module[0].forward
- self.org_module[0].forward = self.forward
+ #self.org_module[0].forward = self.forward
+ self.org_module[0].register_forward_hook(self.forward_hook)
def get_weight(self, oft_blocks, multiplier=None):
+ self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
@@ -66,14 +68,10 @@ class NetworkModuleOFT(network.NetworkModule):
output_shape = self.oft_blocks.shape
return self.finalize_updown(updown, orig_weight, output_shape)
-
- def forward(self, x, y=None):
- x = self.org_forward(x)
- if self.multiplier() == 0.0:
- return x
-
- # calculating R here is excruciatingly slow
- #R = self.get_weight().to(x.device, dtype=x.dtype)
+
+ def forward_hook(self, module, args, output):
+ #print(f'Forward hook in {self.network_key} called')
+ x = output
R = self.R.to(x.device, dtype=x.dtype)
if x.dim() == 4:
@@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule):
else:
x = torch.matmul(x, R)
return x
+
+ # def forward(self, x, y=None):
+ # x = self.org_forward(x)
+ # if self.multiplier() == 0.0:
+ # return x
+
+ # # calculating R here is excruciatingly slow
+ # #R = self.get_weight().to(x.device, dtype=x.dtype)
+ # R = self.R.to(x.device, dtype=x.dtype)
+
+ # if x.dim() == 4:
+ # x = x.permute(0, 2, 3, 1)
+ # x = torch.matmul(x, R)
+ # x = x.permute(0, 3, 1, 2)
+ # else:
+ # x = torch.matmul(x, R)
+ # return x