From a2fad6ee055f3f4e98e46b6c2d912776fe608214 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:34:27 -0700 Subject: test implementation based on kohaku diag-oft implementation --- extensions-builtin/Lora/network_oft.py | 59 ++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e43c9a1d..ff61b369 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from einops import rearrange class ModuleTypeOFT(network.ModuleType): @@ -30,35 +31,51 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight + # def merge_weight(self, R_weight, org_weight): + # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + # if org_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + # else: + # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + # weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + # ) + # return weight def get_weight(self, oft_blocks, multiplier=None): - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + # block_Q = oft_blocks - oft_blocks.transpose(1, 2) + # norm_Q = torch.norm(block_Q.flatten()) + # new_norm_Q = torch.clamp(norm_Q, max=constraint) + # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) + # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + # R = torch.block_diag(*block_R_weighted) + #return R + return self.oft_blocks - return R def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #R = self.get_weight(self.oft_blocks, multiplier) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #merged_weight = self.merge_weight(R, orig_weight) + + orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + orig_weight + ) + weight = rearrange(weight, 'k m ... -> (k m) ...') + + #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight -- cgit v1.2.1 From 65ccd6305fcf72347d5ed68f03095dced865ef6e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:11:32 -0700 Subject: detect diag_oft type --- extensions-builtin/Lora/networks.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 78a97033..7f814706 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -191,10 +191,17 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # kohya_ss OFT module elif sd_module is None and "oft_unet" in key_network_without_network_parts: key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # KohakuBlueLeaf OFT module + if sd_module is None and "oft_diag" in key: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_network] = key continue -- cgit v1.2.1 From d727ddfccdc6d474767be9dc3bf504150e81a8a5 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:13:11 -0700 Subject: no idea what i'm doing, trying to support both type of OFT, kblueleaf diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch --- extensions-builtin/Lora/network_oft.py | 192 +++++++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 47 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff61b369..e102eafc 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,11 +1,12 @@ import torch import network from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["oft_blocks"]): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): return NetworkModuleOFT(net, weights) return None @@ -16,66 +17,117 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) - self.oft_blocks = weights.w["oft_blocks"] - self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] - self.num_blocks = self.dim - - if "Linear" in self.sd_module.__class__.__name__: + self.lin_module = None + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + self.dim = self.oft_blocks.shape[0] + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # alpha is rank if alpha is 0 or None + if self.alpha is None: + pass + self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + else: + raise ValueError("oft_blocks or oft_diag must be in weights dict") + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + if is_linear: self.out_dim = self.sd_module.out_features - elif "Conv" in self.sd_module.__class__.__name__: + #elif hasattr(self.sd_module, "embed_dim"): + # self.out_dim = self.sd_module.embed_dim + #else: + # raise ValueError("Linear sd_module must have out_features or embed_dim") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + elif is_conv: self.out_dim = self.sd_module.out_channels + else: + raise ValueError("sd_module must be Linear or Conv") + - self.constraint = self.alpha * self.out_dim - self.block_size = self.out_dim // self.num_blocks + if self.is_kohya: + self.num_blocks = self.dim + self.block_size = self.out_dim // self.num_blocks + self.constraint = self.alpha * self.out_dim + #elif is_linear or is_conv: + else: + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.constraint = None self.org_module: list[torch.Module] = [self.sd_module] - # def merge_weight(self, R_weight, org_weight): - # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - # if org_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - # else: - # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - # weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - # ) - # return weight + # if is_other_linear: + # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) + # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + # with torch.no_grad(): + # if weight.shape != module.weight.shape: + # weight = weight.reshape(module.weight.shape) + # module.weight.copy_(weight) + # module.to(device=devices.cpu, dtype=devices.dtype) + # module.weight.requires_grad_(False) + # self.lin_module = module + #return module + + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + #weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + #) + return weight def get_weight(self, oft_blocks, multiplier=None): - # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + if self.constraint is not None: + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - # block_Q = oft_blocks - oft_blocks.transpose(1, 2) - # norm_Q = torch.norm(block_Q.flatten()) - # new_norm_Q = torch.clamp(norm_Q, max=constraint) - # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + if self.constraint is not None: + new_norm_Q = torch.clamp(norm_Q, max=constraint) + else: + new_norm_Q = norm_Q + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - # R = torch.block_diag(*block_R_weighted) - #return R - return self.oft_blocks + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + return R + #return self.oft_blocks def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - #R = self.get_weight(self.oft_blocks, multiplier) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - #merged_weight = self.merge_weight(R, orig_weight) - - orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - orig_weight - ) - weight = rearrange(weight, 'k m ... -> (k m) ...') - - #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + + #if self.lin_module is not None: + # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) + # weight = torch.mul(torch.mul(R, multiplier), orig_weight) + #else: + # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + # weight = torch.einsum( + # 'k n m, k n ... -> k m ...', + # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + # orig_weight + # ) + # weight = rearrange(weight, 'k m ... -> (k m) ...') + + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight @@ -100,3 +152,49 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + -- cgit v1.2.1 From fe1967a4c4a02eccfa45b65ee19a5b0773ced31c Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:52:55 -0700 Subject: skip multihead attn for now --- extensions-builtin/Lora/network_oft.py | 54 +++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e102eafc..979a2047 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule): # alpha is rank if alpha is 0 or None if self.alpha is None: pass - self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] else: raise ValueError("oft_blocks or oft_diag must be in weights dict") @@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule): # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim + #self.org_weight = self.org_module[0].weight +# if hasattr(self.sd_module, "in_proj_weight"): +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] +# if hasattr(self.sd_module, "out_proj_weight"): +# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: @@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = self.alpha * self.out_dim #elif is_linear or is_conv: else: - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - self.org_module: list[torch.Module] = [self.sd_module] # if is_other_linear: # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) @@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + if self.is_kohya and not is_other_linear: + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + elif not self.is_kohya and not is_other_linear: + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + else: + # skip for now + updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) + output_shape = (orig_weight.shape[1], orig_weight.shape[1]) #if self.lin_module is not None: # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) # weight = torch.mul(torch.mul(R, multiplier), orig_weight) #else: - # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - # weight = torch.einsum( - # 'k n m, k n ... -> k m ...', - # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - # orig_weight - # ) - # weight = rearrange(weight, 'k m ... -> (k m) ...') - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - output_shape = orig_weight.shape + orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.1 From f6c8201e5663ca2182a66c8eca63ce4801d52849 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 19:35:15 -0700 Subject: refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb --- extensions-builtin/Lora/lyco_helpers.py | 47 ++++++++++++ extensions-builtin/Lora/network_oft.py | 131 ++++++++------------------------ 2 files changed, 77 insertions(+), 101 deletions(-) diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 279b34bc..1679a0ce 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid): up = up.reshape(up.size(0), -1) down = down.reshape(down.size(0), -1) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 979a2047..2be67fe5 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,7 +1,7 @@ import torch import network +from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -11,7 +11,8 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -19,6 +20,7 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -37,61 +39,31 @@ class NetworkModuleOFT(network.NetworkModule): is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + if is_linear: self.out_dim = self.sd_module.out_features - #elif hasattr(self.sd_module, "embed_dim"): - # self.out_dim = self.sd_module.embed_dim - #else: - # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim - #self.org_weight = self.org_module[0].weight -# if hasattr(self.sd_module, "in_proj_weight"): -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] -# if hasattr(self.sd_module, "out_proj_weight"): -# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: raise ValueError("sd_module must be Linear or Conv") - if self.is_kohya: self.num_blocks = self.dim self.block_size = self.out_dim // self.num_blocks self.constraint = self.alpha * self.out_dim - #elif is_linear or is_conv: else: self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - - # if is_other_linear: - # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) - # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - # with torch.no_grad(): - # if weight.shape != module.weight.shape: - # weight = weight.reshape(module.weight.shape) - # module.weight.copy_(weight) - # module.to(device=devices.cpu, dtype=devices.dtype) - # module.weight.requires_grad_(False) - # self.lin_module = module - #return module - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - #weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - #) return weight def get_weight(self, oft_blocks, multiplier=None): @@ -111,48 +83,51 @@ class NetworkModuleOFT(network.NetworkModule): block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R - #return self.oft_blocks + def calc_updown_kohya(self, orig_weight, multiplier): + R = self.get_weight(self.oft_blocks, multiplier) + merged_weight = self.merge_weight(R, orig_weight) - def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - if self.is_kohya and not is_other_linear: - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) - elif not self.is_kohya and not is_other_linear: + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + orig_weight = orig_weight + return self.finalize_updown(updown, orig_weight, output_shape) + + def calc_updown_kb(self, orig_weight, multiplier): + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + + if not is_other_linear: if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - merged_weight + merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) - #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: - # skip for now + # FIXME: skip MultiheadAttention for now updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - #if self.lin_module is not None: - # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - # weight = torch.mul(torch.mul(R, multiplier), orig_weight) - #else: - - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) + def calc_updown(self, orig_weight): + multiplier = self.multiplier() * self.calc_scale() + if self.is_kohya: + return self.calc_updown_kohya(orig_weight, multiplier) + else: + return self.calc_updown_kb(orig_weight, multiplier) + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) @@ -172,49 +147,3 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias - -# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py -def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: - ''' - return a tuple of two value of input dimension decomposed by the number closest to factor - second value is higher or equal than first value. - - In LoRA with Kroneckor Product, first value is a value for weight scale. - secon value is a value for weight. - - Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. - - examples) - factor - -1 2 4 8 16 ... - 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 - 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 - 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 - 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 - 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 - 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 - ''' - - if factor > 0 and (dimension % factor) == 0: - m = factor - n = dimension // factor - if m > n: - n, m = m, n - return m, n - if factor < 0: - factor = dimension - m, n = 1, dimension - length = m + n - while m length or new_m>factor: - break - else: - m, n = new_m, new_n - if m > n: - n, m = m, n - return m, n - -- cgit v1.2.1