From ec718f76b58b183859ed732e11ec748c41a13f76 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 17 Oct 2023 23:35:50 -0700 Subject: wip incorrect OFT implementation --- extensions-builtin/Lora/network_oft.py | 82 ++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 extensions-builtin/Lora/network_oft.py (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py new file mode 100644 index 00000000..9ddb175c --- /dev/null +++ b/extensions-builtin/Lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +import network + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]): + return NetworkModuleOFT(net, weights) + + return None + +# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + + self.dim = self.oft_blocks.shape[0] + self.num_blocks = self.dim + + #if type(self.alpha) == torch.Tensor: + # self.alpha = self.alpha.detach().numpy() + + if "Linear" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_features + elif "Conv" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_channels + + self.constraint = self.alpha * self.out_dim + self.block_size = self.out_dim // self.num_blocks + + self.oft_multiplier = self.multiplier() + + # replace forward method of original linear rather than replacing the module + # self.org_forward = self.sd_module.forward + # self.sd_module.forward = self.forward + + def get_weight(self): + block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + + return R + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) + # W = R*W_0 + updown = orig_weight + R + output_shape = [R.size(0), orig_weight.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.oft_multiplier == 0.0: + # return x + + # R = self.get_weight().to(x.device, dtype=x.dtype) + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x -- cgit v1.2.1 From 1c6efdbba774d603c592debaccd6f5ad827bd1b2 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:16:01 -0700 Subject: inference working but SLOW --- extensions-builtin/Lora/network_oft.py | 73 +++++++++++++++++----------------- 1 file changed, 36 insertions(+), 37 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 9ddb175c..f085eca5 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -12,6 +12,7 @@ class ModuleTypeOFT(network.ModuleType): # adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) self.oft_blocks = weights.w["oft_blocks"] @@ -20,24 +21,29 @@ class NetworkModuleOFT(network.NetworkModule): self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim - #if type(self.alpha) == torch.Tensor: - # self.alpha = self.alpha.detach().numpy() - if "Linear" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_features elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha + #self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks - self.oft_multiplier = self.multiplier() + self.org_module: list[torch.Module] = [self.sd_module] + + self.R = self.get_weight() - # replace forward method of original linear rather than replacing the module - # self.org_forward = self.sd_module.forward - # self.sd_module.forward = self.forward + self.apply_to() + + # replace forward method of original linear rather than replacing the module + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward - def get_weight(self): + def get_weight(self, multiplier=None): + if not multiplier: + multiplier = self.multiplier() block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -45,38 +51,31 @@ class NetworkModuleOFT(network.NetworkModule): I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I - R = torch.block_diag(*block_R_weighted) - #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) - # W = R*W_0 - updown = orig_weight + R - output_shape = [R.size(0), orig_weight.size(1)] + R = self.R + if orig_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + else: + weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R + output_shape = [orig_weight.size(0), R.size(1)] + #output_shape = [R.size(0), orig_weight.size(1)] return self.finalize_updown(updown, orig_weight, output_shape) - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.oft_multiplier == 0.0: - # return x - - # R = self.get_weight().to(x.device, dtype=x.dtype) - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x + def forward(self, x, y=None): + x = self.org_forward(x) + if self.multiplier() == 0.0: + return x + R = self.get_weight().to(x.device, dtype=x.dtype) + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, R) + x = x.permute(0, 3, 1, 2) + else: + x = torch.matmul(x, R) + return x -- cgit v1.2.1 From 853e21d98eada4db9a9fd1ae8eda90cf763e2818 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:27:44 -0700 Subject: faster by using cached R in forward --- extensions-builtin/Lora/network_oft.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f085eca5..68efb1db 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): + # this works R = self.R + + # this causes major deepfrying i.e. just doesn't work + # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + if orig_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", orig_weight, R) else: weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R - output_shape = [orig_weight.size(0), R.size(1)] - #output_shape = [R.size(0), orig_weight.size(1)] + output_shape = self.oft_blocks.shape + + ## this works + # updown = orig_weight @ R + # output_shape = [orig_weight.size(0), R.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x - R = self.get_weight().to(x.device, dtype=x.dtype) + #R = self.get_weight().to(x.device, dtype=x.dtype) + R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) -- cgit v1.2.1 From eb01d7f0e0fb46285985803296a25715165fb3f9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:56:53 -0700 Subject: faster by calculating R in updown and using cached R in forward --- extensions-builtin/Lora/network_oft.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 68efb1db..fd5b0c0f 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -58,17 +58,18 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): # this works - R = self.R + # R = self.R + self.R = self.get_weight(self.multiplier()) - # this causes major deepfrying i.e. just doesn't work + # sending R to device causes major deepfrying i.e. just doesn't work # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - if orig_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - else: - weight = torch.einsum("oi, op -> pi", orig_weight, R) + # if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + # else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R + updown = orig_weight @ self.R output_shape = self.oft_blocks.shape ## this works -- cgit v1.2.1 From 321680ccd0e0404223fbdf4f26498f7d0317fb75 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:41:17 -0700 Subject: refactor: fix constraint, re-use get_weight --- extensions-builtin/Lora/network_oft.py | 40 ++++++++++++++-------------------- 1 file changed, 16 insertions(+), 24 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fd5b0c0f..2af1bc4c 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -9,7 +9,7 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -17,7 +17,6 @@ class NetworkModuleOFT(network.NetworkModule): self.oft_blocks = weights.w["oft_blocks"] self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim @@ -26,64 +25,57 @@ class NetworkModuleOFT(network.NetworkModule): elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha - #self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - - self.R = self.get_weight() - + self.R = self.get_weight(self.oft_blocks) self.apply_to() # replace forward method of original linear rather than replacing the module + # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - def get_weight(self, multiplier=None): - if not multiplier: - multiplier = self.multiplier() - block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + def get_weight(self, oft_blocks, multiplier=None): + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = multiplier * block_R + (1 - multiplier) * I - R = torch.block_diag(*block_R_weighted) + #block_R_weighted = multiplier * block_R + (1 - multiplier) * I + #R = torch.block_diag(*block_R_weighted) + R = torch.block_diag(*block_R) return R def calc_updown(self, orig_weight): - # this works - # R = self.R - self.R = self.get_weight(self.multiplier()) + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - # sending R to device causes major deepfrying i.e. just doesn't work - # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + R = self.get_weight(oft_blocks) + self.R = R # if orig_weight.dim() == 4: # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) # else: # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ self.R + updown = orig_weight @ R output_shape = self.oft_blocks.shape - ## this works - # updown = orig_weight @ R - # output_shape = [orig_weight.size(0), R.size(1)] - return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x + + # calculating R here is excruciatingly slow #R = self.get_weight().to(x.device, dtype=x.dtype) R = self.R.to(x.device, dtype=x.dtype) + if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) -- cgit v1.2.1 From d10c4db57ed08234a7aed5f530f269ff78544ab0 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:52:14 -0700 Subject: style: formatting --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 2af1bc4c..0a87958e 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -37,7 +37,7 @@ class NetworkModuleOFT(network.NetworkModule): def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - + def get_weight(self, oft_blocks, multiplier=None): block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: -- cgit v1.2.1 From 0550659ce6e1c37d1ab05cb8a2cb31d499fa552f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:13:02 -0700 Subject: style: fix ambiguous variable name --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 0a87958e..4e8382c1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -43,8 +43,8 @@ class NetworkModuleOFT(network.NetworkModule): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) R = torch.block_diag(*block_R) -- cgit v1.2.1 From 2d8c894b274d60a3e3563a2ace23c4ebcea9e652 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:43:31 -0700 Subject: refactor: use forward hook instead of custom forward --- extensions-builtin/Lora/network_oft.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 4e8382c1..8e561ab0 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule): # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward - self.org_module[0].forward = self.forward + #self.org_module[0].forward = self.forward + self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): + self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -66,14 +68,10 @@ class NetworkModuleOFT(network.NetworkModule): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def forward(self, x, y=None): - x = self.org_forward(x) - if self.multiplier() == 0.0: - return x - - # calculating R here is excruciatingly slow - #R = self.get_weight().to(x.device, dtype=x.dtype) + + def forward_hook(self, module, args, output): + #print(f'Forward hook in {self.network_key} called') + x = output R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: @@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule): else: x = torch.matmul(x, R) return x + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.multiplier() == 0.0: + # return x + + # # calculating R here is excruciatingly slow + # #R = self.get_weight().to(x.device, dtype=x.dtype) + # R = self.R.to(x.device, dtype=x.dtype) + + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x -- cgit v1.2.1 From 768354772853a1d27a9bf7e41bd6a6e4eac7a9c7 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 14:42:24 -0700 Subject: fix: return orig weights during updown, merge weights before forward --- extensions-builtin/Lora/network_oft.py | 90 ++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 21 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 8e561ab0..f5f32c23 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -29,23 +30,56 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] + self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) + #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) self.R = self.get_weight(self.oft_blocks) + + self.merged_weight = self.merge_weight() self.apply_to() + self.merged = False + + + def merge_weight(self): + org_sd = self.org_module[0].state_dict() + R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) + if self.org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) + else: + weight = torch.einsum("oi, op -> pi", self.org_weight, R) + org_sd['weight'] = weight + # replace weight + #self.org_module[0].load_state_dict(org_sd) + return weight + pass + + def replace_weight(self, new_weight): + org_sd = self.org_module[0].state_dict() + org_sd['weight'] = new_weight + self.org_module[0].load_state_dict(org_sd) + self.merged = True + + def restore_weight(self): + org_sd = self.org_module[0].state_dict() + org_sd['weight'] = self.org_weight + self.org_module[0].load_state_dict(org_sd) + self.merged = False + # replace forward method of original linear rather than replacing the module # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward #self.org_module[0].forward = self.forward + self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): - self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + new_norm_Q = torch.clamp(norm_Q, max=constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) @@ -54,33 +88,47 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = self.get_weight(oft_blocks) - self.R = R + #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + ##self.R = R - # if orig_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - # else: - # weight = torch.einsum("oi, op -> pi", orig_weight, R) + #if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + #else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R - output_shape = self.oft_blocks.shape + #updown = orig_weight @ R + #updown = weight + updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) + #updown = orig_weight + output_shape = orig_weight.shape + #orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) + def pre_forward_hook(self, module, input): + if not self.merged: + self.replace_weight(self.merged_weight) + + def forward_hook(self, module, args, output): + if self.merged: + pass + #self.restore_weight() #print(f'Forward hook in {self.network_key} called') - x = output - R = self.R.to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + #x = output + #R = self.R.to(x.device, dtype=x.dtype) + + #if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + #else: + # x = torch.matmul(x, R) + #return x # def forward(self, x, y=None): # x = self.org_forward(x) -- cgit v1.2.1 From fce86ab7d75690785f0f5b496f1b3aee922c0ae3 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:03:54 -0700 Subject: fix: support multiplier, no forward pass hook --- extensions-builtin/Lora/network_oft.py | 43 ++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 10 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f5f32c23..e0672ba6 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -32,21 +32,27 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) - self.R = self.get_weight(self.oft_blocks) + init_multiplier = self.multiplier() * self.calc_scale() + self.last_multiplier = init_multiplier + self.R = self.get_weight(self.oft_blocks, init_multiplier) self.merged_weight = self.merge_weight() self.apply_to() self.merged = False + # weights_backup = getattr(self.org_module[0], 'network_weights_backup', None) + # if weights_backup is None: + # self.org_module[0].network_weights_backup = self.org_weight + def merge_weight(self): - org_sd = self.org_module[0].state_dict() + #org_sd = self.org_module[0].state_dict() R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) if self.org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) else: weight = torch.einsum("oi, op -> pi", self.org_weight, R) - org_sd['weight'] = weight + #org_sd['weight'] = weight # replace weight #self.org_module[0].load_state_dict(org_sd) return weight @@ -74,6 +80,7 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): + multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -81,9 +88,9 @@ class NetworkModuleOFT(network.NetworkModule): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - #block_R_weighted = multiplier * block_R + (1 - multiplier) * I - #R = torch.block_diag(*block_R_weighted) - R = torch.block_diag(*block_R) + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + #R = torch.block_diag(*block_R) return R @@ -93,6 +100,8 @@ class NetworkModuleOFT(network.NetworkModule): #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) ##self.R = R + #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + ##self.R = R #if orig_weight.dim() == 4: # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) #else: @@ -103,19 +112,33 @@ class NetworkModuleOFT(network.NetworkModule): updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) #updown = orig_weight output_shape = orig_weight.shape - #orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) def pre_forward_hook(self, module, input): - if not self.merged: + multiplier = self.multiplier() * self.calc_scale() + if not multiplier==self.last_multiplier or not self.merged: + + #if multiplier != self.last_multiplier or not self.merged: + self.R = self.get_weight(self.oft_blocks, multiplier) + self.last_multiplier = multiplier + self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) + #elif not self.merged: + # self.replace_weight(self.merged_weight) def forward_hook(self, module, args, output): - if self.merged: - pass + pass + #output = output * self.multiplier() * self.calc_scale() + #if len(args) > 0: + # y = args[0] + # output = output + y + #return output + #if self.merged: + # pass #self.restore_weight() #print(f'Forward hook in {self.network_key} called') -- cgit v1.2.1 From 76f5abdbdb739133eff2ccefa36eac62bea3fa08 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:07:45 -0700 Subject: style: cleanup oft --- extensions-builtin/Lora/network_oft.py | 82 +++------------------------------- 1 file changed, 7 insertions(+), 75 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e0672ba6..e462ccb1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,5 @@ import torch import network -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -31,33 +30,24 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) - #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) + init_multiplier = self.multiplier() * self.calc_scale() self.last_multiplier = init_multiplier + self.R = self.get_weight(self.oft_blocks, init_multiplier) self.merged_weight = self.merge_weight() self.apply_to() self.merged = False - # weights_backup = getattr(self.org_module[0], 'network_weights_backup', None) - # if weights_backup is None: - # self.org_module[0].network_weights_backup = self.org_weight - - def merge_weight(self): - #org_sd = self.org_module[0].state_dict() R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) if self.org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) else: weight = torch.einsum("oi, op -> pi", self.org_weight, R) - #org_sd['weight'] = weight - # replace weight - #self.org_module[0].load_state_dict(org_sd) return weight - pass - + def replace_weight(self, new_weight): org_sd = self.org_module[0].state_dict() org_sd['weight'] = new_weight @@ -70,9 +60,7 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module[0].load_state_dict(org_sd) self.merged = False - - # replace forward method of original linear rather than replacing the module - # how do we revert this to unload the weights? + # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? def apply_to(self): self.org_forward = self.org_module[0].forward #self.org_module[0].forward = self.forward @@ -90,82 +78,26 @@ class NetworkModuleOFT(network.NetworkModule): block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) - #R = torch.block_diag(*block_R) return R def calc_updown(self, orig_weight): - #oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - - #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - ##self.R = R - - #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - ##self.R = R - #if orig_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - #else: - # weight = torch.einsum("oi, op -> pi", orig_weight, R) - - #updown = orig_weight @ R - #updown = weight updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) - #updown = orig_weight output_shape = orig_weight.shape orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def pre_forward_hook(self, module, input): multiplier = self.multiplier() * self.calc_scale() - if not multiplier==self.last_multiplier or not self.merged: - #if multiplier != self.last_multiplier or not self.merged: + if not multiplier==self.last_multiplier or not self.merged: self.R = self.get_weight(self.oft_blocks, multiplier) self.last_multiplier = multiplier self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) - #elif not self.merged: - # self.replace_weight(self.merged_weight) - + def forward_hook(self, module, args, output): pass - #output = output * self.multiplier() * self.calc_scale() - #if len(args) > 0: - # y = args[0] - # output = output + y - #return output - #if self.merged: - # pass - #self.restore_weight() - #print(f'Forward hook in {self.network_key} called') - - #x = output - #R = self.R.to(x.device, dtype=x.dtype) - - #if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - #else: - # x = torch.matmul(x, R) - #return x - - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.multiplier() == 0.0: - # return x - - # # calculating R here is excruciatingly slow - # #R = self.get_weight().to(x.device, dtype=x.dtype) - # R = self.R.to(x.device, dtype=x.dtype) - - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x -- cgit v1.2.1 From de8ee92ed88b855098e273f576a27f4789f0693d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:37:17 -0700 Subject: fix: use merge_weight to cache value --- extensions-builtin/Lora/network_oft.py | 57 ++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 17 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e462ccb1..ebe6740c 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,23 +29,27 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) + #self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) init_multiplier = self.multiplier() * self.calc_scale() self.last_multiplier = init_multiplier self.R = self.get_weight(self.oft_blocks, init_multiplier) + self.hooks = [] self.merged_weight = self.merge_weight() - self.apply_to() + + #self.apply_to() + self.applied = False self.merged = False def merge_weight(self): - R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) - if self.org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) + org_weight = self.org_module[0].weight + R = self.R.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - weight = torch.einsum("oi, op -> pi", self.org_weight, R) + weight = torch.einsum("oi, op -> pi", org_weight, R) return weight def replace_weight(self, new_weight): @@ -55,17 +59,29 @@ class NetworkModuleOFT(network.NetworkModule): self.merged = True def restore_weight(self): - org_sd = self.org_module[0].state_dict() - org_sd['weight'] = self.org_weight - self.org_module[0].load_state_dict(org_sd) - self.merged = False + pass + #org_sd = self.org_module[0].state_dict() + #org_sd['weight'] = self.org_weight + #self.org_module[0].load_state_dict(org_sd) + #self.merged = False # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? def apply_to(self): - self.org_forward = self.org_module[0].forward - #self.org_module[0].forward = self.forward - self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) - self.org_module[0].register_forward_hook(self.forward_hook) + if not self.applied: + self.org_forward = self.org_module[0].forward + #self.org_module[0].forward = self.forward + prehook = self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) + hook = self.org_module[0].register_forward_hook(self.forward_hook) + self.hooks.append(prehook) + self.hooks.append(hook) + self.applied = True + + def remove_from(self): + if self.applied: + for hook in self.hooks: + hook.remove() + self.hooks = [] + self.applied = False def get_weight(self, oft_blocks, multiplier=None): multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) @@ -82,14 +98,22 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): + if not self.applied: + self.apply_to() + + self.merged_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) output_shape = orig_weight.shape - orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + orig_weight = self.merged_weight #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) def pre_forward_hook(self, module, input): + #if not self.applied: + # self.apply_to() + multiplier = self.multiplier() * self.calc_scale() if not multiplier==self.last_multiplier or not self.merged: @@ -98,6 +122,5 @@ class NetworkModuleOFT(network.NetworkModule): self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) - def forward_hook(self, module, args, output): - pass + pass \ No newline at end of file -- cgit v1.2.1 From 4a50c9638c3eac860fb05ae603cd61aabf4cd1a9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 08:54:24 -0700 Subject: refactor: remove used OFT functions --- extensions-builtin/Lora/network_oft.py | 82 +++++----------------------------- 1 file changed, 10 insertions(+), 72 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ebe6740c..3034a407 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,98 +29,36 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - #self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) - init_multiplier = self.multiplier() * self.calc_scale() - self.last_multiplier = init_multiplier - - self.R = self.get_weight(self.oft_blocks, init_multiplier) - - self.hooks = [] - self.merged_weight = self.merge_weight() - - #self.apply_to() - self.applied = False - self.merged = False - - def merge_weight(self): - org_weight = self.org_module[0].weight - R = self.R.to(org_weight.device, dtype=org_weight.dtype) + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) return weight - def replace_weight(self, new_weight): - org_sd = self.org_module[0].state_dict() - org_sd['weight'] = new_weight - self.org_module[0].load_state_dict(org_sd) - self.merged = True - - def restore_weight(self): - pass - #org_sd = self.org_module[0].state_dict() - #org_sd['weight'] = self.org_weight - #self.org_module[0].load_state_dict(org_sd) - #self.merged = False - - # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? - def apply_to(self): - if not self.applied: - self.org_forward = self.org_module[0].forward - #self.org_module[0].forward = self.forward - prehook = self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) - hook = self.org_module[0].register_forward_hook(self.forward_hook) - self.hooks.append(prehook) - self.hooks.append(hook) - self.applied = True - - def remove_from(self): - if self.applied: - for hook in self.hooks: - hook.remove() - self.hooks = [] - self.applied = False - def get_weight(self, oft_blocks, multiplier=None): - multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - if not self.applied: - self.apply_to() - - self.merged_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + R = self.get_weight(self.oft_blocks, self.multiplier()) + merged_weight = self.merge_weight(R, orig_weight) - updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape - orig_weight = self.merged_weight - #output_shape = self.oft_blocks.shape + orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) - - def pre_forward_hook(self, module, input): - #if not self.applied: - # self.apply_to() - - multiplier = self.multiplier() * self.calc_scale() - - if not multiplier==self.last_multiplier or not self.merged: - self.R = self.get_weight(self.oft_blocks, multiplier) - self.last_multiplier = multiplier - self.merged_weight = self.merge_weight() - self.replace_weight(self.merged_weight) - - def forward_hook(self, module, args, output): - pass \ No newline at end of file -- cgit v1.2.1 From 3b8515d2c9abad7f0ccaac0215803716e861ee0e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 09:27:48 -0700 Subject: fix: multiplier applied twice in finalize_updown --- extensions-builtin/Lora/network_oft.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 3034a407..efbdd296 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -54,7 +54,8 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): - R = self.get_weight(self.oft_blocks, self.multiplier()) + multiplier = self.multiplier() * self.calc_scale() + R = self.get_weight(self.oft_blocks, multiplier) merged_weight = self.merge_weight(R, orig_weight) updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight @@ -62,3 +63,23 @@ class NetworkModuleOFT(network.NetworkModule): orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) + + # override to remove the multiplier/scale factor; it's already multiplied in get_weight + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): + #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) + + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + + return updown, ex_bias -- cgit v1.2.1 From 6523edb8a45d4e09f11f3b4e1d133afa6fb65e53 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 09:31:15 -0700 Subject: style: conform style --- extensions-builtin/Lora/network_oft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index efbdd296..e43c9a1d 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -63,7 +63,7 @@ class NetworkModuleOFT(network.NetworkModule): orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) - + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) -- cgit v1.2.1 From a2fad6ee055f3f4e98e46b6c2d912776fe608214 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:34:27 -0700 Subject: test implementation based on kohaku diag-oft implementation --- extensions-builtin/Lora/network_oft.py | 59 ++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 21 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e43c9a1d..ff61b369 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from einops import rearrange class ModuleTypeOFT(network.ModuleType): @@ -30,35 +31,51 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight + # def merge_weight(self, R_weight, org_weight): + # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + # if org_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + # else: + # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + # weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + # ) + # return weight def get_weight(self, oft_blocks, multiplier=None): - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + # block_Q = oft_blocks - oft_blocks.transpose(1, 2) + # norm_Q = torch.norm(block_Q.flatten()) + # new_norm_Q = torch.clamp(norm_Q, max=constraint) + # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) + # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + # R = torch.block_diag(*block_R_weighted) + #return R + return self.oft_blocks - return R def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #R = self.get_weight(self.oft_blocks, multiplier) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #merged_weight = self.merge_weight(R, orig_weight) + + orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + orig_weight + ) + weight = rearrange(weight, 'k m ... -> (k m) ...') + + #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight -- cgit v1.2.1 From d727ddfccdc6d474767be9dc3bf504150e81a8a5 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:13:11 -0700 Subject: no idea what i'm doing, trying to support both type of OFT, kblueleaf diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch --- extensions-builtin/Lora/network_oft.py | 192 +++++++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 47 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff61b369..e102eafc 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,11 +1,12 @@ import torch import network from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["oft_blocks"]): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): return NetworkModuleOFT(net, weights) return None @@ -16,66 +17,117 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) - self.oft_blocks = weights.w["oft_blocks"] - self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] - self.num_blocks = self.dim - - if "Linear" in self.sd_module.__class__.__name__: + self.lin_module = None + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + self.dim = self.oft_blocks.shape[0] + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # alpha is rank if alpha is 0 or None + if self.alpha is None: + pass + self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + else: + raise ValueError("oft_blocks or oft_diag must be in weights dict") + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + if is_linear: self.out_dim = self.sd_module.out_features - elif "Conv" in self.sd_module.__class__.__name__: + #elif hasattr(self.sd_module, "embed_dim"): + # self.out_dim = self.sd_module.embed_dim + #else: + # raise ValueError("Linear sd_module must have out_features or embed_dim") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + elif is_conv: self.out_dim = self.sd_module.out_channels + else: + raise ValueError("sd_module must be Linear or Conv") + - self.constraint = self.alpha * self.out_dim - self.block_size = self.out_dim // self.num_blocks + if self.is_kohya: + self.num_blocks = self.dim + self.block_size = self.out_dim // self.num_blocks + self.constraint = self.alpha * self.out_dim + #elif is_linear or is_conv: + else: + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.constraint = None self.org_module: list[torch.Module] = [self.sd_module] - # def merge_weight(self, R_weight, org_weight): - # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - # if org_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - # else: - # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - # weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - # ) - # return weight + # if is_other_linear: + # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) + # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + # with torch.no_grad(): + # if weight.shape != module.weight.shape: + # weight = weight.reshape(module.weight.shape) + # module.weight.copy_(weight) + # module.to(device=devices.cpu, dtype=devices.dtype) + # module.weight.requires_grad_(False) + # self.lin_module = module + #return module + + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + #weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + #) + return weight def get_weight(self, oft_blocks, multiplier=None): - # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + if self.constraint is not None: + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - # block_Q = oft_blocks - oft_blocks.transpose(1, 2) - # norm_Q = torch.norm(block_Q.flatten()) - # new_norm_Q = torch.clamp(norm_Q, max=constraint) - # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + if self.constraint is not None: + new_norm_Q = torch.clamp(norm_Q, max=constraint) + else: + new_norm_Q = norm_Q + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - # R = torch.block_diag(*block_R_weighted) - #return R - return self.oft_blocks + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + return R + #return self.oft_blocks def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - #R = self.get_weight(self.oft_blocks, multiplier) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - #merged_weight = self.merge_weight(R, orig_weight) - - orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - orig_weight - ) - weight = rearrange(weight, 'k m ... -> (k m) ...') - - #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + + #if self.lin_module is not None: + # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) + # weight = torch.mul(torch.mul(R, multiplier), orig_weight) + #else: + # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + # weight = torch.einsum( + # 'k n m, k n ... -> k m ...', + # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + # orig_weight + # ) + # weight = rearrange(weight, 'k m ... -> (k m) ...') + + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight @@ -100,3 +152,49 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + -- cgit v1.2.1 From fe1967a4c4a02eccfa45b65ee19a5b0773ced31c Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:52:55 -0700 Subject: skip multihead attn for now --- extensions-builtin/Lora/network_oft.py | 54 +++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 17 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e102eafc..979a2047 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule): # alpha is rank if alpha is 0 or None if self.alpha is None: pass - self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] else: raise ValueError("oft_blocks or oft_diag must be in weights dict") @@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule): # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim + #self.org_weight = self.org_module[0].weight +# if hasattr(self.sd_module, "in_proj_weight"): +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] +# if hasattr(self.sd_module, "out_proj_weight"): +# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: @@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = self.alpha * self.out_dim #elif is_linear or is_conv: else: - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - self.org_module: list[torch.Module] = [self.sd_module] # if is_other_linear: # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) @@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + if self.is_kohya and not is_other_linear: + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + elif not self.is_kohya and not is_other_linear: + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + else: + # skip for now + updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) + output_shape = (orig_weight.shape[1], orig_weight.shape[1]) #if self.lin_module is not None: # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) # weight = torch.mul(torch.mul(R, multiplier), orig_weight) #else: - # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - # weight = torch.einsum( - # 'k n m, k n ... -> k m ...', - # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - # orig_weight - # ) - # weight = rearrange(weight, 'k m ... -> (k m) ...') - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - output_shape = orig_weight.shape + orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.1 From f6c8201e5663ca2182a66c8eca63ce4801d52849 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 19:35:15 -0700 Subject: refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb --- extensions-builtin/Lora/network_oft.py | 131 ++++++++------------------------- 1 file changed, 30 insertions(+), 101 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 979a2047..2be67fe5 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,7 +1,7 @@ import torch import network +from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -11,7 +11,8 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -19,6 +20,7 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -37,61 +39,31 @@ class NetworkModuleOFT(network.NetworkModule): is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + if is_linear: self.out_dim = self.sd_module.out_features - #elif hasattr(self.sd_module, "embed_dim"): - # self.out_dim = self.sd_module.embed_dim - #else: - # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim - #self.org_weight = self.org_module[0].weight -# if hasattr(self.sd_module, "in_proj_weight"): -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] -# if hasattr(self.sd_module, "out_proj_weight"): -# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: raise ValueError("sd_module must be Linear or Conv") - if self.is_kohya: self.num_blocks = self.dim self.block_size = self.out_dim // self.num_blocks self.constraint = self.alpha * self.out_dim - #elif is_linear or is_conv: else: self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - - # if is_other_linear: - # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) - # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - # with torch.no_grad(): - # if weight.shape != module.weight.shape: - # weight = weight.reshape(module.weight.shape) - # module.weight.copy_(weight) - # module.to(device=devices.cpu, dtype=devices.dtype) - # module.weight.requires_grad_(False) - # self.lin_module = module - #return module - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - #weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - #) return weight def get_weight(self, oft_blocks, multiplier=None): @@ -111,48 +83,51 @@ class NetworkModuleOFT(network.NetworkModule): block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R - #return self.oft_blocks + def calc_updown_kohya(self, orig_weight, multiplier): + R = self.get_weight(self.oft_blocks, multiplier) + merged_weight = self.merge_weight(R, orig_weight) - def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - if self.is_kohya and not is_other_linear: - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) - elif not self.is_kohya and not is_other_linear: + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + orig_weight = orig_weight + return self.finalize_updown(updown, orig_weight, output_shape) + + def calc_updown_kb(self, orig_weight, multiplier): + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + + if not is_other_linear: if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - merged_weight + merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) - #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: - # skip for now + # FIXME: skip MultiheadAttention for now updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - #if self.lin_module is not None: - # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - # weight = torch.mul(torch.mul(R, multiplier), orig_weight) - #else: - - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) + def calc_updown(self, orig_weight): + multiplier = self.multiplier() * self.calc_scale() + if self.is_kohya: + return self.calc_updown_kohya(orig_weight, multiplier) + else: + return self.calc_updown_kb(orig_weight, multiplier) + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) @@ -172,49 +147,3 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias - -# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py -def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: - ''' - return a tuple of two value of input dimension decomposed by the number closest to factor - second value is higher or equal than first value. - - In LoRA with Kroneckor Product, first value is a value for weight scale. - secon value is a value for weight. - - Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. - - examples) - factor - -1 2 4 8 16 ... - 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 - 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 - 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 - 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 - 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 - 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 - ''' - - if factor > 0 and (dimension % factor) == 0: - m = factor - n = dimension // factor - if m > n: - n, m = m, n - return m, n - if factor < 0: - factor = dimension - m, n = 1, dimension - length = m + n - while m length or new_m>factor: - break - else: - m, n = new_m, new_n - if m > n: - n, m = m, n - return m, n - -- cgit v1.2.1 From 329c8bacce706811776e1c1c6a0d39b46886a268 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 4 Nov 2023 14:54:36 -0700 Subject: refactor: use same updown for both kohya OFT and LyCORIS diag-oft --- extensions-builtin/Lora/network_oft.py | 91 +++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 17 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 2be67fe5..e4aa082b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -2,6 +2,7 @@ import torch import network from lyco_helpers import factorization from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -24,12 +25,14 @@ class NetworkModuleOFT(network.NetworkModule): # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True - self.oft_blocks = weights.w["oft_blocks"] + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] + self.dim = self.oft_blocks.shape[0] # lora dim + #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... -> (k m) ...') elif "oft_diag" in weights.w.keys(): self.is_kohya = False - self.oft_blocks = weights.w["oft_diag"] + self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size) + # alpha is rank if alpha is 0 or None if self.alpha is None: pass @@ -51,12 +54,57 @@ class NetworkModuleOFT(network.NetworkModule): raise ValueError("sd_module must be Linear or Conv") if self.is_kohya: - self.num_blocks = self.dim - self.block_size = self.out_dim // self.num_blocks + #self.num_blocks = self.dim + #self.block_size = self.out_dim // self.num_blocks + #self.block_size = self.dim + #self.num_blocks = self.out_dim // self.block_size self.constraint = self.alpha * self.out_dim + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) else: - self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) + + if is_other_linear: + self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True) + + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) + + if weight is None and none_ok: + return None + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + + with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) @@ -77,7 +125,8 @@ class NetworkModuleOFT(network.NetworkModule): else: new_norm_Q = norm_Q block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + m_I = torch.eye(self.num_blocks, device=oft_blocks.device).unsqueeze(0).repeat(self.block_size, 1, 1) + #m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I @@ -97,25 +146,33 @@ class NetworkModuleOFT(network.NetworkModule): is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] if not is_other_linear: - if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - orig_weight=orig_weight.permute(1, 0) + #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + # orig_weight=orig_weight.permute(1, 0) + + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + + # without this line the results are significantly worse / less accurate + oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) + + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + R, merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - orig_weight=orig_weight.permute(1, 0) + #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + # orig_weight=orig_weight.permute(1, 0) updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: # FIXME: skip MultiheadAttention for now + #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) @@ -123,10 +180,10 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - if self.is_kohya: - return self.calc_updown_kohya(orig_weight, multiplier) - else: - return self.calc_updown_kb(orig_weight, multiplier) + #if self.is_kohya: + # return self.calc_updown_kohya(orig_weight, multiplier) + #else: + return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): -- cgit v1.2.1 From bbf00a96afb2215f13cc72a7908225ae300c423d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 4 Nov 2023 14:56:47 -0700 Subject: refactor: remove unused function --- extensions-builtin/Lora/network_oft.py | 47 ---------------------------------- 1 file changed, 47 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e4aa082b..93402bb2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -2,7 +2,6 @@ import torch import network from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -54,58 +53,12 @@ class NetworkModuleOFT(network.NetworkModule): raise ValueError("sd_module must be Linear or Conv") if self.is_kohya: - #self.num_blocks = self.dim - #self.block_size = self.out_dim // self.num_blocks - #self.block_size = self.dim - #self.num_blocks = self.out_dim // self.block_size self.constraint = self.alpha * self.out_dim self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - if is_other_linear: - self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True) - - - def create_module(self, weights, key, none_ok=False): - weight = weights.get(key) - - if weight is None and none_ok: - return None - - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] - is_conv = type(self.sd_module) in [torch.nn.Conv2d] - - if is_linear: - weight = weight.reshape(weight.shape[0], -1) - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif is_conv and key == "lora_down.weight" or key == "dyn_up": - if len(weight.shape) == 2: - weight = weight.reshape(weight.shape[0], -1, 1, 1) - - if weight.shape[2] != 1 or weight.shape[3] != 1: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - else: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif is_conv and key == "lora_mid.weight": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - elif is_conv and key == "lora_up.weight" or key == "dyn_down": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - else: - raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - - with torch.no_grad(): - if weight.shape != module.weight.shape: - weight = weight.reshape(module.weight.shape) - module.weight.copy_(weight) - - module.to(device=devices.cpu, dtype=devices.dtype) - module.weight.requires_grad_(False) - - return module - - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: -- cgit v1.2.1 From d6d0b22e6657fc84039e82ee735a57101bfe7c17 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 15 Nov 2023 03:08:50 -0800 Subject: fix: ignore calc_scale() for COFT which has very small alpha --- extensions-builtin/Lora/network_oft.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 93402bb2..c45a8d23 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule): is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] if not is_other_linear: - #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - # orig_weight=orig_weight.permute(1, 0) - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - # without this line the results are significantly worse / less accurate + # ensure skew-symmetric matrix oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) @@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - # orig_weight=orig_weight.permute(1, 0) - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: @@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule): return self.finalize_updown(updown, orig_weight, output_shape) def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - #if self.is_kohya: - # return self.calc_updown_kohya(orig_weight, multiplier) - #else: + # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it + #multiplier = self.multiplier() * self.calc_scale() + multiplier = self.multiplier() + return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight -- cgit v1.2.1 From eb667e715ad3eea981f6263c143ab0422e5340c9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 15 Nov 2023 18:28:48 -0800 Subject: feat: LyCORIS/kohya OFT network support --- extensions-builtin/Lora/network_oft.py | 108 ++++++++------------------------- 1 file changed, 26 insertions(+), 82 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index c45a8d23..05c37811 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -11,8 +11,8 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py -# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -25,117 +25,61 @@ class NetworkModuleOFT(network.NetworkModule): if "oft_blocks" in weights.w.keys(): self.is_kohya = True self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) - self.alpha = weights.w["alpha"] + self.alpha = weights.w["alpha"] # alpha is constraint self.dim = self.oft_blocks.shape[0] # lora dim - #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... -> (k m) ...') + # LyCORIS elif "oft_diag" in weights.w.keys(): self.is_kohya = False - self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size) - - # alpha is rank if alpha is 0 or None - if self.alpha is None: - pass - self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] - else: - raise ValueError("oft_blocks or oft_diag must be in weights dict") + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported if is_linear: self.out_dim = self.sd_module.out_features - elif is_other_linear: - self.out_dim = self.sd_module.embed_dim elif is_conv: self.out_dim = self.sd_module.out_channels - else: - raise ValueError("sd_module must be Linear or Conv") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim if self.is_kohya: self.constraint = self.alpha * self.out_dim - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight - - def get_weight(self, oft_blocks, multiplier=None): - if self.constraint is not None: - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - if self.constraint is not None: - new_norm_Q = torch.clamp(norm_Q, max=constraint) - else: - new_norm_Q = norm_Q - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.num_blocks, device=oft_blocks.device).unsqueeze(0).repeat(self.block_size, 1, 1) - #m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + def calc_updown_kb(self, orig_weight, multiplier): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) - return R + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - def calc_updown_kohya(self, orig_weight, multiplier): - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) - - def calc_updown_kb(self, orig_weight, multiplier): - is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] - - if not is_other_linear: - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - - # ensure skew-symmetric matrix - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) - - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - output_shape = orig_weight.shape - else: - # FIXME: skip MultiheadAttention for now - #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) - output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - return self.finalize_updown(updown, orig_weight, output_shape) def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it - #multiplier = self.multiplier() * self.calc_scale() + # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) - if self.bias is not None: updown = updown.reshape(self.bias.shape) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) -- cgit v1.2.1