aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorv0xie <28695009+v0xie@users.noreply.github.com>2023-10-19 13:13:02 -0700
committerv0xie <28695009+v0xie@users.noreply.github.com>2023-10-19 13:13:02 -0700
commit0550659ce6e1c37d1ab05cb8a2cb31d499fa552f (patch)
tree4c962d4e7133e4943dce1c85807b623879642207
parentd10c4db57ed08234a7aed5f530f269ff78544ab0 (diff)
style: fix ambiguous variable name
-rw-r--r--extensions-builtin/Lora/network_oft.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 0a87958e..4e8382c1 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -43,8 +43,8 @@ class NetworkModuleOFT(network.NetworkModule):
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
- I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
- block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
+ m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
+ block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse())
#block_R_weighted = multiplier * block_R + (1 - multiplier) * I
#R = torch.block_diag(*block_R_weighted)
R = torch.block_diag(*block_R)