From aab385d01b4311726127397552d791f4d71b7147 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 3 Sep 2023 11:56:02 +0900 Subject: thread safe extra network list_items --- extensions-builtin/Lora/ui_extra_networks_lora.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a78..e9f30062 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -66,11 +66,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): return item def list_items(self): - for index, name in enumerate(networks.available_networks): - item = self.create_item(name, index) - - if item is not None: - yield item + with self.thread_lock: + for index, name in enumerate(networks.available_networks): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat] -- cgit v1.2.1 From 25de9a785cc9e93c16626db6ab5b16824443de53 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 16:56:19 +0900 Subject: Revert "thread safe extra network list_items" This reverts commit aab385d01b4311726127397552d791f4d71b7147. --- extensions-builtin/Lora/ui_extra_networks_lora.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index e9f30062..55409a78 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -66,11 +66,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): return item def list_items(self): - with self.thread_lock: - for index, name in enumerate(networks.available_networks): - item = self.create_item(name, index) - if item is not None: - yield item + for index, name in enumerate(networks.available_networks): + item = self.create_item(name, index) + + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat] -- cgit v1.2.1 From f5959c1c3022c454de22fab749d0f06ab3219868 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 17:05:50 +0900 Subject: thread safe extra network using list --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a78..e74daa77 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -66,7 +66,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): return item def list_items(self): - for index, name in enumerate(networks.available_networks): + names = list(networks.available_networks) + for index, name in enumerate(names): item = self.create_item(name, index) if item is not None: -- cgit v1.2.1 From e785402b6acca12108e15224ff80d58817ab3c27 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 17:28:06 +0900 Subject: return nothing if not found --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index e74daa77..dac90a86 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) + if lora_on_disk is None: + return path, ext = os.path.splitext(lora_on_disk.filename) @@ -69,7 +71,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): names = list(networks.available_networks) for index, name in enumerate(names): item = self.create_item(name, index) - if item is not None: yield item -- cgit v1.2.1 From 74b80e72115af46bf1c04167a30f9ec5025cb464 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 12 Sep 2023 09:29:07 +0900 Subject: add comment --- extensions-builtin/Lora/ui_extra_networks_lora.py | 1 + 1 file changed, 1 insertion(+) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index dac90a86..df02c663 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -68,6 +68,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): return item def list_items(self): + # instantiate a list to protect against concurrent modification names = list(networks.available_networks) for index, name in enumerate(names): item = self.create_item(name, index) -- cgit v1.2.1 From ec718f76b58b183859ed732e11ec748c41a13f76 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 17 Oct 2023 23:35:50 -0700 Subject: wip incorrect OFT implementation --- extensions-builtin/Lora/network_oft.py | 82 ++++++++++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 5 +++ 2 files changed, 87 insertions(+) create mode 100644 extensions-builtin/Lora/network_oft.py (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py new file mode 100644 index 00000000..9ddb175c --- /dev/null +++ b/extensions-builtin/Lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +import network + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]): + return NetworkModuleOFT(net, weights) + + return None + +# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + + self.dim = self.oft_blocks.shape[0] + self.num_blocks = self.dim + + #if type(self.alpha) == torch.Tensor: + # self.alpha = self.alpha.detach().numpy() + + if "Linear" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_features + elif "Conv" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_channels + + self.constraint = self.alpha * self.out_dim + self.block_size = self.out_dim // self.num_blocks + + self.oft_multiplier = self.multiplier() + + # replace forward method of original linear rather than replacing the module + # self.org_forward = self.sd_module.forward + # self.sd_module.forward = self.forward + + def get_weight(self): + block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + + return R + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) + # W = R*W_0 + updown = orig_weight + R + output_shape = [R.size(0), orig_weight.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.oft_multiplier == 0.0: + # return x + + # R = self.get_weight().to(x.device, dtype=x.dtype) + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4..bd1f1b75 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -11,6 +11,7 @@ import network_ia3 import network_lokr import network_full import network_norm +import network_oft import torch from typing import Union @@ -28,6 +29,7 @@ module_types = [ network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), network_glora.ModuleTypeGLora(), + network_oft.ModuleTypeOFT(), ] @@ -183,6 +185,9 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is None: -- cgit v1.2.1 From 1c6efdbba774d603c592debaccd6f5ad827bd1b2 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:16:01 -0700 Subject: inference working but SLOW --- extensions-builtin/Lora/network_oft.py | 73 +++++++++++++++++----------------- extensions-builtin/Lora/networks.py | 42 +++++++++++++++++-- 2 files changed, 75 insertions(+), 40 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 9ddb175c..f085eca5 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -12,6 +12,7 @@ class ModuleTypeOFT(network.ModuleType): # adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) self.oft_blocks = weights.w["oft_blocks"] @@ -20,24 +21,29 @@ class NetworkModuleOFT(network.NetworkModule): self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim - #if type(self.alpha) == torch.Tensor: - # self.alpha = self.alpha.detach().numpy() - if "Linear" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_features elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha + #self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks - self.oft_multiplier = self.multiplier() + self.org_module: list[torch.Module] = [self.sd_module] + + self.R = self.get_weight() - # replace forward method of original linear rather than replacing the module - # self.org_forward = self.sd_module.forward - # self.sd_module.forward = self.forward + self.apply_to() + + # replace forward method of original linear rather than replacing the module + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward - def get_weight(self): + def get_weight(self, multiplier=None): + if not multiplier: + multiplier = self.multiplier() block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -45,38 +51,31 @@ class NetworkModuleOFT(network.NetworkModule): I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I - R = torch.block_diag(*block_R_weighted) - #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) - # W = R*W_0 - updown = orig_weight + R - output_shape = [R.size(0), orig_weight.size(1)] + R = self.R + if orig_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + else: + weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R + output_shape = [orig_weight.size(0), R.size(1)] + #output_shape = [R.size(0), orig_weight.size(1)] return self.finalize_updown(updown, orig_weight, output_shape) - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.oft_multiplier == 0.0: - # return x - - # R = self.get_weight().to(x.device, dtype=x.dtype) - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x + def forward(self, x, y=None): + x = self.org_forward(x) + if self.multiplier() == 0.0: + return x + R = self.get_weight().to(x.device, dtype=x.dtype) + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, R) + x = x.permute(0, 3, 1, 2) + else: + x = torch.matmul(x, R) + return x diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index bd1f1b75..e5e73450 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,6 +169,10 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict + + #if key_network_without_network_parts == "oft_unet": + # print(key_network_without_network_parts) + # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -185,15 +189,39 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - elif sd_module is None and "oft_unet" in key_network_without_network_parts: - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") - sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is None: key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] + # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] + # TODO: Change matchedm odules based on whether all linear, conv, etc + + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + #key_no_suffix = key.rsplit("_to_", 1)[0] + ## Match all modules of class CrossAttention + #replace_module_list = [] + #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: + # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if module_type in module.__class__.__name__] + + #matched_module = replace_module_list.get(key_no_suffix, None) + #if key.endswith('to_q'): + # sd_module = matched_module.to_q or None + #if key.endswith('to_k'): + # sd_module = matched_module.to_k or None + #if key.endswith('to_v'): + # sd_module = matched_module.to_v or None + #if key.endswith('to_out_0'): + # sd_module = matched_module.to_out[0] or None + #if key.endswith('to_out_1'): + # sd_module = matched_module.to_out[1] or None + + if sd_module is None: keys_failed_to_match[key_network] = key continue @@ -214,6 +242,14 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module + + # replaces forward method of original Linear + # applied_to_count = 0 + #for key, created_module in net.modules.items(): + # if isinstance(created_module, network_oft.NetworkModuleOFT): + # net_module.apply_to() + #applied_to_count += 1 + # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): -- cgit v1.2.1 From 853e21d98eada4db9a9fd1ae8eda90cf763e2818 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:27:44 -0700 Subject: faster by using cached R in forward --- extensions-builtin/Lora/network_oft.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f085eca5..68efb1db 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): + # this works R = self.R + + # this causes major deepfrying i.e. just doesn't work + # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + if orig_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", orig_weight, R) else: weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R - output_shape = [orig_weight.size(0), R.size(1)] - #output_shape = [R.size(0), orig_weight.size(1)] + output_shape = self.oft_blocks.shape + + ## this works + # updown = orig_weight @ R + # output_shape = [orig_weight.size(0), R.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x - R = self.get_weight().to(x.device, dtype=x.dtype) + #R = self.get_weight().to(x.device, dtype=x.dtype) + R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) -- cgit v1.2.1 From eb01d7f0e0fb46285985803296a25715165fb3f9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:56:53 -0700 Subject: faster by calculating R in updown and using cached R in forward --- extensions-builtin/Lora/network_oft.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 68efb1db..fd5b0c0f 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -58,17 +58,18 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): # this works - R = self.R + # R = self.R + self.R = self.get_weight(self.multiplier()) - # this causes major deepfrying i.e. just doesn't work + # sending R to device causes major deepfrying i.e. just doesn't work # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - if orig_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - else: - weight = torch.einsum("oi, op -> pi", orig_weight, R) + # if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + # else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R + updown = orig_weight @ self.R output_shape = self.oft_blocks.shape ## this works -- cgit v1.2.1 From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:56:17 +0800 Subject: Add fp8 for sd unet --- extensions-builtin/Lora/network.py | 2 +- extensions-builtin/Lora/network_full.py | 4 ++-- extensions-builtin/Lora/network_glora.py | 10 +++++----- extensions-builtin/Lora/network_hada.py | 12 ++++++------ extensions-builtin/Lora/network_ia3.py | 2 +- extensions-builtin/Lora/network_lokr.py | 18 +++++++++--------- extensions-builtin/Lora/network_lora.py | 6 +++--- extensions-builtin/Lora/network_norm.py | 4 ++-- extensions-builtin/Lora/networks.py | 6 +++--- 9 files changed, 32 insertions(+), 32 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8d..a62e5eff 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ class NetworkModule: def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) updown = updown.reshape(output_shape) if len(output_shape) == 4: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e9..f221c95f 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(orig_weight.device) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d4870..efe5c681 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695..d95a0fd1 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule): self.t2 = weights.w.get("hada_t2") def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(orig_weight.device) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc4249..96faeaf3 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule): self.on_input = weights.w["on_input"].item() def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + w = self.w.to(orig_weight.device) output_shape = [w.size(0), orig_weight.size(1)] if self.on_input: diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab..fcdaeafd 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule): def calc_updown(self, orig_weight): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(orig_weight.device) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(orig_weight.device) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c..4cc40295 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule): return module def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) output_shape = [up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(orig_weight.device) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce450158..d25afcbb 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4..8ea4ea60 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown + self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: - self.bias = torch.nn.Parameter(ex_bias) + self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias += ex_bias + self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 -- cgit v1.2.1 From 321680ccd0e0404223fbdf4f26498f7d0317fb75 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:41:17 -0700 Subject: refactor: fix constraint, re-use get_weight --- extensions-builtin/Lora/network_oft.py | 40 ++++++++++++++-------------------- 1 file changed, 16 insertions(+), 24 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fd5b0c0f..2af1bc4c 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -9,7 +9,7 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -17,7 +17,6 @@ class NetworkModuleOFT(network.NetworkModule): self.oft_blocks = weights.w["oft_blocks"] self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim @@ -26,64 +25,57 @@ class NetworkModuleOFT(network.NetworkModule): elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha - #self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - - self.R = self.get_weight() - + self.R = self.get_weight(self.oft_blocks) self.apply_to() # replace forward method of original linear rather than replacing the module + # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - def get_weight(self, multiplier=None): - if not multiplier: - multiplier = self.multiplier() - block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + def get_weight(self, oft_blocks, multiplier=None): + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = multiplier * block_R + (1 - multiplier) * I - R = torch.block_diag(*block_R_weighted) + #block_R_weighted = multiplier * block_R + (1 - multiplier) * I + #R = torch.block_diag(*block_R_weighted) + R = torch.block_diag(*block_R) return R def calc_updown(self, orig_weight): - # this works - # R = self.R - self.R = self.get_weight(self.multiplier()) + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - # sending R to device causes major deepfrying i.e. just doesn't work - # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + R = self.get_weight(oft_blocks) + self.R = R # if orig_weight.dim() == 4: # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) # else: # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ self.R + updown = orig_weight @ R output_shape = self.oft_blocks.shape - ## this works - # updown = orig_weight @ R - # output_shape = [orig_weight.size(0), R.size(1)] - return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x + + # calculating R here is excruciatingly slow #R = self.get_weight().to(x.device, dtype=x.dtype) R = self.R.to(x.device, dtype=x.dtype) + if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) -- cgit v1.2.1 From d10c4db57ed08234a7aed5f530f269ff78544ab0 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:52:14 -0700 Subject: style: formatting --- extensions-builtin/Lora/network_oft.py | 4 ++-- extensions-builtin/Lora/networks.py | 35 ---------------------------------- 2 files changed, 2 insertions(+), 37 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 2af1bc4c..0a87958e 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -37,7 +37,7 @@ class NetworkModuleOFT(network.NetworkModule): def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - + def get_weight(self, oft_blocks, multiplier=None): block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index e5e73450..78a97033 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,10 +169,6 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict - - #if key_network_without_network_parts == "oft_unet": - # print(key_network_without_network_parts) - # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -196,31 +192,8 @@ def load_network(name, network_on_disk): sd_module = shared.sd_model.network_layer_mapping.get(key, None) elif sd_module is None and "oft_unet" in key_network_without_network_parts: - # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] - # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] - # TODO: Change matchedm odules based on whether all linear, conv, etc - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - #key_no_suffix = key.rsplit("_to_", 1)[0] - ## Match all modules of class CrossAttention - #replace_module_list = [] - #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: - # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if module_type in module.__class__.__name__] - - #matched_module = replace_module_list.get(key_no_suffix, None) - #if key.endswith('to_q'): - # sd_module = matched_module.to_q or None - #if key.endswith('to_k'): - # sd_module = matched_module.to_k or None - #if key.endswith('to_v'): - # sd_module = matched_module.to_v or None - #if key.endswith('to_out_0'): - # sd_module = matched_module.to_out[0] or None - #if key.endswith('to_out_1'): - # sd_module = matched_module.to_out[1] or None - if sd_module is None: keys_failed_to_match[key_network] = key @@ -242,14 +215,6 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module - - # replaces forward method of original Linear - # applied_to_count = 0 - #for key, created_module in net.modules.items(): - # if isinstance(created_module, network_oft.NetworkModuleOFT): - # net_module.apply_to() - #applied_to_count += 1 - # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): -- cgit v1.2.1 From 0550659ce6e1c37d1ab05cb8a2cb31d499fa552f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:13:02 -0700 Subject: style: fix ambiguous variable name --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 0a87958e..4e8382c1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -43,8 +43,8 @@ class NetworkModuleOFT(network.NetworkModule): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) R = torch.block_diag(*block_R) -- cgit v1.2.1 From 2d8c894b274d60a3e3563a2ace23c4ebcea9e652 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:43:31 -0700 Subject: refactor: use forward hook instead of custom forward --- extensions-builtin/Lora/network_oft.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 4e8382c1..8e561ab0 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule): # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward - self.org_module[0].forward = self.forward + #self.org_module[0].forward = self.forward + self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): + self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -66,14 +68,10 @@ class NetworkModuleOFT(network.NetworkModule): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def forward(self, x, y=None): - x = self.org_forward(x) - if self.multiplier() == 0.0: - return x - - # calculating R here is excruciatingly slow - #R = self.get_weight().to(x.device, dtype=x.dtype) + + def forward_hook(self, module, args, output): + #print(f'Forward hook in {self.network_key} called') + x = output R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: @@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule): else: x = torch.matmul(x, R) return x + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.multiplier() == 0.0: + # return x + + # # calculating R here is excruciatingly slow + # #R = self.get_weight().to(x.device, dtype=x.dtype) + # R = self.R.to(x.device, dtype=x.dtype) + + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x -- cgit v1.2.1 From 768354772853a1d27a9bf7e41bd6a6e4eac7a9c7 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 14:42:24 -0700 Subject: fix: return orig weights during updown, merge weights before forward --- extensions-builtin/Lora/network_oft.py | 90 ++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 21 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 8e561ab0..f5f32c23 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -29,23 +30,56 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] + self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) + #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) self.R = self.get_weight(self.oft_blocks) + + self.merged_weight = self.merge_weight() self.apply_to() + self.merged = False + + + def merge_weight(self): + org_sd = self.org_module[0].state_dict() + R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) + if self.org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) + else: + weight = torch.einsum("oi, op -> pi", self.org_weight, R) + org_sd['weight'] = weight + # replace weight + #self.org_module[0].load_state_dict(org_sd) + return weight + pass + + def replace_weight(self, new_weight): + org_sd = self.org_module[0].state_dict() + org_sd['weight'] = new_weight + self.org_module[0].load_state_dict(org_sd) + self.merged = True + + def restore_weight(self): + org_sd = self.org_module[0].state_dict() + org_sd['weight'] = self.org_weight + self.org_module[0].load_state_dict(org_sd) + self.merged = False + # replace forward method of original linear rather than replacing the module # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward #self.org_module[0].forward = self.forward + self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): - self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + new_norm_Q = torch.clamp(norm_Q, max=constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) @@ -54,33 +88,47 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = self.get_weight(oft_blocks) - self.R = R + #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + ##self.R = R - # if orig_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - # else: - # weight = torch.einsum("oi, op -> pi", orig_weight, R) + #if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + #else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R - output_shape = self.oft_blocks.shape + #updown = orig_weight @ R + #updown = weight + updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) + #updown = orig_weight + output_shape = orig_weight.shape + #orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) + def pre_forward_hook(self, module, input): + if not self.merged: + self.replace_weight(self.merged_weight) + + def forward_hook(self, module, args, output): + if self.merged: + pass + #self.restore_weight() #print(f'Forward hook in {self.network_key} called') - x = output - R = self.R.to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + #x = output + #R = self.R.to(x.device, dtype=x.dtype) + + #if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + #else: + # x = torch.matmul(x, R) + #return x # def forward(self, x, y=None): # x = self.org_forward(x) -- cgit v1.2.1 From fce86ab7d75690785f0f5b496f1b3aee922c0ae3 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:03:54 -0700 Subject: fix: support multiplier, no forward pass hook --- extensions-builtin/Lora/network_oft.py | 43 ++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 10 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f5f32c23..e0672ba6 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -32,21 +32,27 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) - self.R = self.get_weight(self.oft_blocks) + init_multiplier = self.multiplier() * self.calc_scale() + self.last_multiplier = init_multiplier + self.R = self.get_weight(self.oft_blocks, init_multiplier) self.merged_weight = self.merge_weight() self.apply_to() self.merged = False + # weights_backup = getattr(self.org_module[0], 'network_weights_backup', None) + # if weights_backup is None: + # self.org_module[0].network_weights_backup = self.org_weight + def merge_weight(self): - org_sd = self.org_module[0].state_dict() + #org_sd = self.org_module[0].state_dict() R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) if self.org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) else: weight = torch.einsum("oi, op -> pi", self.org_weight, R) - org_sd['weight'] = weight + #org_sd['weight'] = weight # replace weight #self.org_module[0].load_state_dict(org_sd) return weight @@ -74,6 +80,7 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): + multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -81,9 +88,9 @@ class NetworkModuleOFT(network.NetworkModule): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - #block_R_weighted = multiplier * block_R + (1 - multiplier) * I - #R = torch.block_diag(*block_R_weighted) - R = torch.block_diag(*block_R) + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + #R = torch.block_diag(*block_R) return R @@ -93,6 +100,8 @@ class NetworkModuleOFT(network.NetworkModule): #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) ##self.R = R + #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + ##self.R = R #if orig_weight.dim() == 4: # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) #else: @@ -103,19 +112,33 @@ class NetworkModuleOFT(network.NetworkModule): updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) #updown = orig_weight output_shape = orig_weight.shape - #orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) def pre_forward_hook(self, module, input): - if not self.merged: + multiplier = self.multiplier() * self.calc_scale() + if not multiplier==self.last_multiplier or not self.merged: + + #if multiplier != self.last_multiplier or not self.merged: + self.R = self.get_weight(self.oft_blocks, multiplier) + self.last_multiplier = multiplier + self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) + #elif not self.merged: + # self.replace_weight(self.merged_weight) def forward_hook(self, module, args, output): - if self.merged: - pass + pass + #output = output * self.multiplier() * self.calc_scale() + #if len(args) > 0: + # y = args[0] + # output = output + y + #return output + #if self.merged: + # pass #self.restore_weight() #print(f'Forward hook in {self.network_key} called') -- cgit v1.2.1 From 76f5abdbdb739133eff2ccefa36eac62bea3fa08 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:07:45 -0700 Subject: style: cleanup oft --- extensions-builtin/Lora/network_oft.py | 82 +++------------------------------- 1 file changed, 7 insertions(+), 75 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e0672ba6..e462ccb1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,5 @@ import torch import network -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -31,33 +30,24 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) - #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) + init_multiplier = self.multiplier() * self.calc_scale() self.last_multiplier = init_multiplier + self.R = self.get_weight(self.oft_blocks, init_multiplier) self.merged_weight = self.merge_weight() self.apply_to() self.merged = False - # weights_backup = getattr(self.org_module[0], 'network_weights_backup', None) - # if weights_backup is None: - # self.org_module[0].network_weights_backup = self.org_weight - - def merge_weight(self): - #org_sd = self.org_module[0].state_dict() R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) if self.org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) else: weight = torch.einsum("oi, op -> pi", self.org_weight, R) - #org_sd['weight'] = weight - # replace weight - #self.org_module[0].load_state_dict(org_sd) return weight - pass - + def replace_weight(self, new_weight): org_sd = self.org_module[0].state_dict() org_sd['weight'] = new_weight @@ -70,9 +60,7 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module[0].load_state_dict(org_sd) self.merged = False - - # replace forward method of original linear rather than replacing the module - # how do we revert this to unload the weights? + # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? def apply_to(self): self.org_forward = self.org_module[0].forward #self.org_module[0].forward = self.forward @@ -90,82 +78,26 @@ class NetworkModuleOFT(network.NetworkModule): block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) - #R = torch.block_diag(*block_R) return R def calc_updown(self, orig_weight): - #oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - - #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - ##self.R = R - - #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - ##self.R = R - #if orig_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - #else: - # weight = torch.einsum("oi, op -> pi", orig_weight, R) - - #updown = orig_weight @ R - #updown = weight updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) - #updown = orig_weight output_shape = orig_weight.shape orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def pre_forward_hook(self, module, input): multiplier = self.multiplier() * self.calc_scale() - if not multiplier==self.last_multiplier or not self.merged: - #if multiplier != self.last_multiplier or not self.merged: + if not multiplier==self.last_multiplier or not self.merged: self.R = self.get_weight(self.oft_blocks, multiplier) self.last_multiplier = multiplier self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) - #elif not self.merged: - # self.replace_weight(self.merged_weight) - + def forward_hook(self, module, args, output): pass - #output = output * self.multiplier() * self.calc_scale() - #if len(args) > 0: - # y = args[0] - # output = output + y - #return output - #if self.merged: - # pass - #self.restore_weight() - #print(f'Forward hook in {self.network_key} called') - - #x = output - #R = self.R.to(x.device, dtype=x.dtype) - - #if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - #else: - # x = torch.matmul(x, R) - #return x - - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.multiplier() == 0.0: - # return x - - # # calculating R here is excruciatingly slow - # #R = self.get_weight().to(x.device, dtype=x.dtype) - # R = self.R.to(x.device, dtype=x.dtype) - - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x -- cgit v1.2.1 From de8ee92ed88b855098e273f576a27f4789f0693d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:37:17 -0700 Subject: fix: use merge_weight to cache value --- extensions-builtin/Lora/network_oft.py | 57 ++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 17 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e462ccb1..ebe6740c 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,23 +29,27 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) + #self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) init_multiplier = self.multiplier() * self.calc_scale() self.last_multiplier = init_multiplier self.R = self.get_weight(self.oft_blocks, init_multiplier) + self.hooks = [] self.merged_weight = self.merge_weight() - self.apply_to() + + #self.apply_to() + self.applied = False self.merged = False def merge_weight(self): - R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) - if self.org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) + org_weight = self.org_module[0].weight + R = self.R.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - weight = torch.einsum("oi, op -> pi", self.org_weight, R) + weight = torch.einsum("oi, op -> pi", org_weight, R) return weight def replace_weight(self, new_weight): @@ -55,17 +59,29 @@ class NetworkModuleOFT(network.NetworkModule): self.merged = True def restore_weight(self): - org_sd = self.org_module[0].state_dict() - org_sd['weight'] = self.org_weight - self.org_module[0].load_state_dict(org_sd) - self.merged = False + pass + #org_sd = self.org_module[0].state_dict() + #org_sd['weight'] = self.org_weight + #self.org_module[0].load_state_dict(org_sd) + #self.merged = False # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? def apply_to(self): - self.org_forward = self.org_module[0].forward - #self.org_module[0].forward = self.forward - self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) - self.org_module[0].register_forward_hook(self.forward_hook) + if not self.applied: + self.org_forward = self.org_module[0].forward + #self.org_module[0].forward = self.forward + prehook = self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) + hook = self.org_module[0].register_forward_hook(self.forward_hook) + self.hooks.append(prehook) + self.hooks.append(hook) + self.applied = True + + def remove_from(self): + if self.applied: + for hook in self.hooks: + hook.remove() + self.hooks = [] + self.applied = False def get_weight(self, oft_blocks, multiplier=None): multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) @@ -82,14 +98,22 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): + if not self.applied: + self.apply_to() + + self.merged_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) output_shape = orig_weight.shape - orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + orig_weight = self.merged_weight #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) def pre_forward_hook(self, module, input): + #if not self.applied: + # self.apply_to() + multiplier = self.multiplier() * self.calc_scale() if not multiplier==self.last_multiplier or not self.merged: @@ -98,6 +122,5 @@ class NetworkModuleOFT(network.NetworkModule): self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) - def forward_hook(self, module, args, output): - pass + pass \ No newline at end of file -- cgit v1.2.1 From 4a50c9638c3eac860fb05ae603cd61aabf4cd1a9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 08:54:24 -0700 Subject: refactor: remove used OFT functions --- extensions-builtin/Lora/network_oft.py | 82 +++++----------------------------- 1 file changed, 10 insertions(+), 72 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ebe6740c..3034a407 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,98 +29,36 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - #self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) - init_multiplier = self.multiplier() * self.calc_scale() - self.last_multiplier = init_multiplier - - self.R = self.get_weight(self.oft_blocks, init_multiplier) - - self.hooks = [] - self.merged_weight = self.merge_weight() - - #self.apply_to() - self.applied = False - self.merged = False - - def merge_weight(self): - org_weight = self.org_module[0].weight - R = self.R.to(org_weight.device, dtype=org_weight.dtype) + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) return weight - def replace_weight(self, new_weight): - org_sd = self.org_module[0].state_dict() - org_sd['weight'] = new_weight - self.org_module[0].load_state_dict(org_sd) - self.merged = True - - def restore_weight(self): - pass - #org_sd = self.org_module[0].state_dict() - #org_sd['weight'] = self.org_weight - #self.org_module[0].load_state_dict(org_sd) - #self.merged = False - - # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? - def apply_to(self): - if not self.applied: - self.org_forward = self.org_module[0].forward - #self.org_module[0].forward = self.forward - prehook = self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) - hook = self.org_module[0].register_forward_hook(self.forward_hook) - self.hooks.append(prehook) - self.hooks.append(hook) - self.applied = True - - def remove_from(self): - if self.applied: - for hook in self.hooks: - hook.remove() - self.hooks = [] - self.applied = False - def get_weight(self, oft_blocks, multiplier=None): - multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - if not self.applied: - self.apply_to() - - self.merged_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + R = self.get_weight(self.oft_blocks, self.multiplier()) + merged_weight = self.merge_weight(R, orig_weight) - updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape - orig_weight = self.merged_weight - #output_shape = self.oft_blocks.shape + orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) - - def pre_forward_hook(self, module, input): - #if not self.applied: - # self.apply_to() - - multiplier = self.multiplier() * self.calc_scale() - - if not multiplier==self.last_multiplier or not self.merged: - self.R = self.get_weight(self.oft_blocks, multiplier) - self.last_multiplier = multiplier - self.merged_weight = self.merge_weight() - self.replace_weight(self.merged_weight) - - def forward_hook(self, module, args, output): - pass \ No newline at end of file -- cgit v1.2.1 From 3b8515d2c9abad7f0ccaac0215803716e861ee0e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 09:27:48 -0700 Subject: fix: multiplier applied twice in finalize_updown --- extensions-builtin/Lora/network_oft.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 3034a407..efbdd296 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -54,7 +54,8 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): - R = self.get_weight(self.oft_blocks, self.multiplier()) + multiplier = self.multiplier() * self.calc_scale() + R = self.get_weight(self.oft_blocks, multiplier) merged_weight = self.merge_weight(R, orig_weight) updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight @@ -62,3 +63,23 @@ class NetworkModuleOFT(network.NetworkModule): orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) + + # override to remove the multiplier/scale factor; it's already multiplied in get_weight + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): + #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) + + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + + return updown, ex_bias -- cgit v1.2.1 From 6523edb8a45d4e09f11f3b4e1d133afa6fb65e53 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 09:31:15 -0700 Subject: style: conform style --- extensions-builtin/Lora/network_oft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index efbdd296..e43c9a1d 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -63,7 +63,7 @@ class NetworkModuleOFT(network.NetworkModule): orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) - + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) -- cgit v1.2.1 From a2fad6ee055f3f4e98e46b6c2d912776fe608214 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:34:27 -0700 Subject: test implementation based on kohaku diag-oft implementation --- extensions-builtin/Lora/network_oft.py | 59 ++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 21 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e43c9a1d..ff61b369 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from einops import rearrange class ModuleTypeOFT(network.ModuleType): @@ -30,35 +31,51 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight + # def merge_weight(self, R_weight, org_weight): + # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + # if org_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + # else: + # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + # weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + # ) + # return weight def get_weight(self, oft_blocks, multiplier=None): - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + # block_Q = oft_blocks - oft_blocks.transpose(1, 2) + # norm_Q = torch.norm(block_Q.flatten()) + # new_norm_Q = torch.clamp(norm_Q, max=constraint) + # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) + # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + # R = torch.block_diag(*block_R_weighted) + #return R + return self.oft_blocks - return R def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #R = self.get_weight(self.oft_blocks, multiplier) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #merged_weight = self.merge_weight(R, orig_weight) + + orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + orig_weight + ) + weight = rearrange(weight, 'k m ... -> (k m) ...') + + #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight -- cgit v1.2.1 From 65ccd6305fcf72347d5ed68f03095dced865ef6e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:11:32 -0700 Subject: detect diag_oft type --- extensions-builtin/Lora/networks.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 78a97033..7f814706 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -191,10 +191,17 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # kohya_ss OFT module elif sd_module is None and "oft_unet" in key_network_without_network_parts: key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # KohakuBlueLeaf OFT module + if sd_module is None and "oft_diag" in key: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_network] = key continue -- cgit v1.2.1 From d727ddfccdc6d474767be9dc3bf504150e81a8a5 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:13:11 -0700 Subject: no idea what i'm doing, trying to support both type of OFT, kblueleaf diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch --- extensions-builtin/Lora/network_oft.py | 192 +++++++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 47 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff61b369..e102eafc 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,11 +1,12 @@ import torch import network from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["oft_blocks"]): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): return NetworkModuleOFT(net, weights) return None @@ -16,66 +17,117 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) - self.oft_blocks = weights.w["oft_blocks"] - self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] - self.num_blocks = self.dim - - if "Linear" in self.sd_module.__class__.__name__: + self.lin_module = None + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + self.dim = self.oft_blocks.shape[0] + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # alpha is rank if alpha is 0 or None + if self.alpha is None: + pass + self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + else: + raise ValueError("oft_blocks or oft_diag must be in weights dict") + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + if is_linear: self.out_dim = self.sd_module.out_features - elif "Conv" in self.sd_module.__class__.__name__: + #elif hasattr(self.sd_module, "embed_dim"): + # self.out_dim = self.sd_module.embed_dim + #else: + # raise ValueError("Linear sd_module must have out_features or embed_dim") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + elif is_conv: self.out_dim = self.sd_module.out_channels + else: + raise ValueError("sd_module must be Linear or Conv") + - self.constraint = self.alpha * self.out_dim - self.block_size = self.out_dim // self.num_blocks + if self.is_kohya: + self.num_blocks = self.dim + self.block_size = self.out_dim // self.num_blocks + self.constraint = self.alpha * self.out_dim + #elif is_linear or is_conv: + else: + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.constraint = None self.org_module: list[torch.Module] = [self.sd_module] - # def merge_weight(self, R_weight, org_weight): - # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - # if org_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - # else: - # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - # weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - # ) - # return weight + # if is_other_linear: + # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) + # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + # with torch.no_grad(): + # if weight.shape != module.weight.shape: + # weight = weight.reshape(module.weight.shape) + # module.weight.copy_(weight) + # module.to(device=devices.cpu, dtype=devices.dtype) + # module.weight.requires_grad_(False) + # self.lin_module = module + #return module + + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + #weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + #) + return weight def get_weight(self, oft_blocks, multiplier=None): - # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + if self.constraint is not None: + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - # block_Q = oft_blocks - oft_blocks.transpose(1, 2) - # norm_Q = torch.norm(block_Q.flatten()) - # new_norm_Q = torch.clamp(norm_Q, max=constraint) - # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + if self.constraint is not None: + new_norm_Q = torch.clamp(norm_Q, max=constraint) + else: + new_norm_Q = norm_Q + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - # R = torch.block_diag(*block_R_weighted) - #return R - return self.oft_blocks + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + return R + #return self.oft_blocks def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - #R = self.get_weight(self.oft_blocks, multiplier) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - #merged_weight = self.merge_weight(R, orig_weight) - - orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - orig_weight - ) - weight = rearrange(weight, 'k m ... -> (k m) ...') - - #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + + #if self.lin_module is not None: + # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) + # weight = torch.mul(torch.mul(R, multiplier), orig_weight) + #else: + # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + # weight = torch.einsum( + # 'k n m, k n ... -> k m ...', + # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + # orig_weight + # ) + # weight = rearrange(weight, 'k m ... -> (k m) ...') + + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight @@ -100,3 +152,49 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + -- cgit v1.2.1 From fe1967a4c4a02eccfa45b65ee19a5b0773ced31c Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:52:55 -0700 Subject: skip multihead attn for now --- extensions-builtin/Lora/network_oft.py | 54 +++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 17 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e102eafc..979a2047 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule): # alpha is rank if alpha is 0 or None if self.alpha is None: pass - self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] else: raise ValueError("oft_blocks or oft_diag must be in weights dict") @@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule): # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim + #self.org_weight = self.org_module[0].weight +# if hasattr(self.sd_module, "in_proj_weight"): +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] +# if hasattr(self.sd_module, "out_proj_weight"): +# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: @@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = self.alpha * self.out_dim #elif is_linear or is_conv: else: - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - self.org_module: list[torch.Module] = [self.sd_module] # if is_other_linear: # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) @@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + if self.is_kohya and not is_other_linear: + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + elif not self.is_kohya and not is_other_linear: + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + else: + # skip for now + updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) + output_shape = (orig_weight.shape[1], orig_weight.shape[1]) #if self.lin_module is not None: # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) # weight = torch.mul(torch.mul(R, multiplier), orig_weight) #else: - # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - # weight = torch.einsum( - # 'k n m, k n ... -> k m ...', - # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - # orig_weight - # ) - # weight = rearrange(weight, 'k m ... -> (k m) ...') - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - output_shape = orig_weight.shape + orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.1 From f6c8201e5663ca2182a66c8eca63ce4801d52849 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 19:35:15 -0700 Subject: refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb --- extensions-builtin/Lora/lyco_helpers.py | 47 ++++++++++++ extensions-builtin/Lora/network_oft.py | 131 ++++++++------------------------ 2 files changed, 77 insertions(+), 101 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 279b34bc..1679a0ce 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid): up = up.reshape(up.size(0), -1) down = down.reshape(down.size(0), -1) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 979a2047..2be67fe5 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,7 +1,7 @@ import torch import network +from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -11,7 +11,8 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -19,6 +20,7 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -37,61 +39,31 @@ class NetworkModuleOFT(network.NetworkModule): is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + if is_linear: self.out_dim = self.sd_module.out_features - #elif hasattr(self.sd_module, "embed_dim"): - # self.out_dim = self.sd_module.embed_dim - #else: - # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim - #self.org_weight = self.org_module[0].weight -# if hasattr(self.sd_module, "in_proj_weight"): -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] -# if hasattr(self.sd_module, "out_proj_weight"): -# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: raise ValueError("sd_module must be Linear or Conv") - if self.is_kohya: self.num_blocks = self.dim self.block_size = self.out_dim // self.num_blocks self.constraint = self.alpha * self.out_dim - #elif is_linear or is_conv: else: self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - - # if is_other_linear: - # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) - # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - # with torch.no_grad(): - # if weight.shape != module.weight.shape: - # weight = weight.reshape(module.weight.shape) - # module.weight.copy_(weight) - # module.to(device=devices.cpu, dtype=devices.dtype) - # module.weight.requires_grad_(False) - # self.lin_module = module - #return module - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - #weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - #) return weight def get_weight(self, oft_blocks, multiplier=None): @@ -111,48 +83,51 @@ class NetworkModuleOFT(network.NetworkModule): block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R - #return self.oft_blocks + def calc_updown_kohya(self, orig_weight, multiplier): + R = self.get_weight(self.oft_blocks, multiplier) + merged_weight = self.merge_weight(R, orig_weight) - def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - if self.is_kohya and not is_other_linear: - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) - elif not self.is_kohya and not is_other_linear: + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + orig_weight = orig_weight + return self.finalize_updown(updown, orig_weight, output_shape) + + def calc_updown_kb(self, orig_weight, multiplier): + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + + if not is_other_linear: if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - merged_weight + merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) - #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: - # skip for now + # FIXME: skip MultiheadAttention for now updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - #if self.lin_module is not None: - # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - # weight = torch.mul(torch.mul(R, multiplier), orig_weight) - #else: - - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) + def calc_updown(self, orig_weight): + multiplier = self.multiplier() * self.calc_scale() + if self.is_kohya: + return self.calc_updown_kohya(orig_weight, multiplier) + else: + return self.calc_updown_kb(orig_weight, multiplier) + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) @@ -172,49 +147,3 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias - -# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py -def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: - ''' - return a tuple of two value of input dimension decomposed by the number closest to factor - second value is higher or equal than first value. - - In LoRA with Kroneckor Product, first value is a value for weight scale. - secon value is a value for weight. - - Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. - - examples) - factor - -1 2 4 8 16 ... - 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 - 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 - 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 - 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 - 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 - 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 - ''' - - if factor > 0 and (dimension % factor) == 0: - m = factor - n = dimension // factor - if m > n: - n, m = m, n - return m, n - if factor < 0: - factor = dimension - m, n = 1, dimension - length = m + n - while m length or new_m>factor: - break - else: - m, n = new_m, new_n - if m > n: - n, m = m, n - return m, n - -- cgit v1.2.1 From 329c8bacce706811776e1c1c6a0d39b46886a268 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 4 Nov 2023 14:54:36 -0700 Subject: refactor: use same updown for both kohya OFT and LyCORIS diag-oft --- extensions-builtin/Lora/network_oft.py | 91 +++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 17 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 2be67fe5..e4aa082b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -2,6 +2,7 @@ import torch import network from lyco_helpers import factorization from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -24,12 +25,14 @@ class NetworkModuleOFT(network.NetworkModule): # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True - self.oft_blocks = weights.w["oft_blocks"] + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] + self.dim = self.oft_blocks.shape[0] # lora dim + #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... -> (k m) ...') elif "oft_diag" in weights.w.keys(): self.is_kohya = False - self.oft_blocks = weights.w["oft_diag"] + self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size) + # alpha is rank if alpha is 0 or None if self.alpha is None: pass @@ -51,12 +54,57 @@ class NetworkModuleOFT(network.NetworkModule): raise ValueError("sd_module must be Linear or Conv") if self.is_kohya: - self.num_blocks = self.dim - self.block_size = self.out_dim // self.num_blocks + #self.num_blocks = self.dim + #self.block_size = self.out_dim // self.num_blocks + #self.block_size = self.dim + #self.num_blocks = self.out_dim // self.block_size self.constraint = self.alpha * self.out_dim + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) else: - self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) + + if is_other_linear: + self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True) + + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) + + if weight is None and none_ok: + return None + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + + with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) @@ -77,7 +125,8 @@ class NetworkModuleOFT(network.NetworkModule): else: new_norm_Q = norm_Q block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + m_I = torch.eye(self.num_blocks, device=oft_blocks.device).unsqueeze(0).repeat(self.block_size, 1, 1) + #m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I @@ -97,25 +146,33 @@ class NetworkModuleOFT(network.NetworkModule): is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] if not is_other_linear: - if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - orig_weight=orig_weight.permute(1, 0) + #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + # orig_weight=orig_weight.permute(1, 0) + + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + + # without this line the results are significantly worse / less accurate + oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) + + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + R, merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - orig_weight=orig_weight.permute(1, 0) + #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + # orig_weight=orig_weight.permute(1, 0) updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: # FIXME: skip MultiheadAttention for now + #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) @@ -123,10 +180,10 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - if self.is_kohya: - return self.calc_updown_kohya(orig_weight, multiplier) - else: - return self.calc_updown_kb(orig_weight, multiplier) + #if self.is_kohya: + # return self.calc_updown_kohya(orig_weight, multiplier) + #else: + return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): -- cgit v1.2.1 From bbf00a96afb2215f13cc72a7908225ae300c423d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 4 Nov 2023 14:56:47 -0700 Subject: refactor: remove unused function --- extensions-builtin/Lora/network_oft.py | 47 ---------------------------------- 1 file changed, 47 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e4aa082b..93402bb2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -2,7 +2,6 @@ import torch import network from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -54,58 +53,12 @@ class NetworkModuleOFT(network.NetworkModule): raise ValueError("sd_module must be Linear or Conv") if self.is_kohya: - #self.num_blocks = self.dim - #self.block_size = self.out_dim // self.num_blocks - #self.block_size = self.dim - #self.num_blocks = self.out_dim // self.block_size self.constraint = self.alpha * self.out_dim self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - if is_other_linear: - self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True) - - - def create_module(self, weights, key, none_ok=False): - weight = weights.get(key) - - if weight is None and none_ok: - return None - - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] - is_conv = type(self.sd_module) in [torch.nn.Conv2d] - - if is_linear: - weight = weight.reshape(weight.shape[0], -1) - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif is_conv and key == "lora_down.weight" or key == "dyn_up": - if len(weight.shape) == 2: - weight = weight.reshape(weight.shape[0], -1, 1, 1) - - if weight.shape[2] != 1 or weight.shape[3] != 1: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - else: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif is_conv and key == "lora_mid.weight": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - elif is_conv and key == "lora_up.weight" or key == "dyn_down": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - else: - raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - - with torch.no_grad(): - if weight.shape != module.weight.shape: - weight = weight.reshape(module.weight.shape) - module.weight.copy_(weight) - - module.to(device=devices.cpu, dtype=devices.dtype) - module.weight.requires_grad_(False) - - return module - - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: -- cgit v1.2.1 From d6d0b22e6657fc84039e82ee735a57101bfe7c17 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 15 Nov 2023 03:08:50 -0800 Subject: fix: ignore calc_scale() for COFT which has very small alpha --- extensions-builtin/Lora/network_oft.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 93402bb2..c45a8d23 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule): is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] if not is_other_linear: - #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - # orig_weight=orig_weight.permute(1, 0) - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - # without this line the results are significantly worse / less accurate + # ensure skew-symmetric matrix oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) @@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - # orig_weight=orig_weight.permute(1, 0) - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: @@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule): return self.finalize_updown(updown, orig_weight, output_shape) def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - #if self.is_kohya: - # return self.calc_updown_kohya(orig_weight, multiplier) - #else: + # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it + #multiplier = self.multiplier() * self.calc_scale() + multiplier = self.multiplier() + return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight -- cgit v1.2.1 From eb667e715ad3eea981f6263c143ab0422e5340c9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 15 Nov 2023 18:28:48 -0800 Subject: feat: LyCORIS/kohya OFT network support --- extensions-builtin/Lora/network_oft.py | 108 ++++++++------------------------- 1 file changed, 26 insertions(+), 82 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index c45a8d23..05c37811 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -11,8 +11,8 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py -# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -25,117 +25,61 @@ class NetworkModuleOFT(network.NetworkModule): if "oft_blocks" in weights.w.keys(): self.is_kohya = True self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) - self.alpha = weights.w["alpha"] + self.alpha = weights.w["alpha"] # alpha is constraint self.dim = self.oft_blocks.shape[0] # lora dim - #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... -> (k m) ...') + # LyCORIS elif "oft_diag" in weights.w.keys(): self.is_kohya = False - self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size) - - # alpha is rank if alpha is 0 or None - if self.alpha is None: - pass - self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] - else: - raise ValueError("oft_blocks or oft_diag must be in weights dict") + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported if is_linear: self.out_dim = self.sd_module.out_features - elif is_other_linear: - self.out_dim = self.sd_module.embed_dim elif is_conv: self.out_dim = self.sd_module.out_channels - else: - raise ValueError("sd_module must be Linear or Conv") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim if self.is_kohya: self.constraint = self.alpha * self.out_dim - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight - - def get_weight(self, oft_blocks, multiplier=None): - if self.constraint is not None: - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - if self.constraint is not None: - new_norm_Q = torch.clamp(norm_Q, max=constraint) - else: - new_norm_Q = norm_Q - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.num_blocks, device=oft_blocks.device).unsqueeze(0).repeat(self.block_size, 1, 1) - #m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + def calc_updown_kb(self, orig_weight, multiplier): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) - return R + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - def calc_updown_kohya(self, orig_weight, multiplier): - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) - - def calc_updown_kb(self, orig_weight, multiplier): - is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] - - if not is_other_linear: - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - - # ensure skew-symmetric matrix - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) - - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - output_shape = orig_weight.shape - else: - # FIXME: skip MultiheadAttention for now - #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) - output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - return self.finalize_updown(updown, orig_weight, output_shape) def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it - #multiplier = self.multiplier() * self.calc_scale() + # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) - if self.bias is not None: updown = updown.reshape(self.bias.shape) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) -- cgit v1.2.1 From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: Option for using fp16 weight when apply lora --- extensions-builtin/Lora/networks.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0170dbfb..d22ed843 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 -- cgit v1.2.1 From 16bdcce92d5b482d50cdc32a8f308040d320b6c9 Mon Sep 17 00:00:00 2001 From: Rene Kroon Date: Fri, 8 Dec 2023 21:19:29 +0100 Subject: #13354: solve lora loading issue --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7f814706..629bf853 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) -- cgit v1.2.1 From 735c9e8059384d4f640e5582413c30871f83eac5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:38:32 +0800 Subject: Fix network_oft --- extensions-builtin/Lora/network_oft.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c37811..44465f7a 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -53,12 +53,17 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): + I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -70,15 +75,10 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -94,4 +94,5 @@ class NetworkModuleOFT(network.NetworkModule): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown, ex_bias + # Ignore calc_scale, which is not used in OFT. + return updown * self.multiplier(), ex_bias -- cgit v1.2.1 From 265bc26c21264d63956e8f30f1ce31dec917fc76 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:43:24 +0800 Subject: Use self.scale instead of custom finalize --- extensions-builtin/Lora/network_oft.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 44465f7a..e3ae61a2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -78,21 +80,3 @@ class NetworkModuleOFT(network.NetworkModule): print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - # Ignore calc_scale, which is not used in OFT. - return updown * self.multiplier(), ex_bias -- cgit v1.2.1 From 8fc67f3851babd4575d3312b931d5e7c2b0c78c6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:44:49 +0800 Subject: remove debug print --- extensions-builtin/Lora/network_oft.py | 1 - 1 file changed, 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e3ae61a2..ff4eb59b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -77,6 +77,5 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.1 From 3772a82a70769fe1aac884a75bf5a3313fb83328 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:47:13 +0800 Subject: better naming and correct order for device. --- extensions-builtin/Lora/network_oft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff4eb59b..fa647020 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,14 +56,15 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) -- cgit v1.2.1 From 59d060fd5ea93fcc3fdbfbd13b6e20fda06ecf94 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 30 Dec 2023 17:11:03 +0900 Subject: More lora not found warning --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 985b2753..72ebd624 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr import logging import os import re @@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}' + sd_hijack.model_hijack.comments.append(lora_not_found_message) + if shared.opts.lora_not_found_warning_console: + print(f'\n{lora_not_found_message}\n') + if shared.opts.lora_not_found_gradio_warning: + gr.Warning(lora_not_found_message) purge_networks_from_memory() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c..1518f7e5 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), + "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), })) -- cgit v1.2.1 From bc5ae74c7d8949bab37e260b16e76889b9968099 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 21:52:27 +0100 Subject: Added negative prompts to extra networks lora --- extensions-builtin/Lora/ui_edit_user_metadata.py | 14 ++++++++++++-- extensions-builtin/Lora/ui_extra_networks_lora.py | 9 +++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7011909..f7859b21 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,14 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text + user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -127,6 +129,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), + float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -162,7 +166,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") + self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -198,6 +203,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -211,7 +218,10 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, self.edit_notes, ] + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663..09ce2a05 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -45,6 +45,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + negative_prompt = item["user_metadata"].get("negative text") + preferred_negative_weight = item["user_metadata"].get("negative weight") + item["negative_prompt"] = quote_js("") + if negative_prompt: + neg_prompt = negative_prompt + if (preferred_negative_weight > 0): + neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version -- cgit v1.2.1 From a2f23f9d22dde87bf2529dcb2854a6a5d3d44278 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 22:16:51 +0100 Subject: Code Style fixes --- extensions-builtin/Lora/ui_extra_networks_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 09ce2a05..9a6624e3 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -52,8 +52,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): neg_prompt = negative_prompt if (preferred_negative_weight > 0): neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) - + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version -- cgit v1.2.1 From d4945f4422e5a0bf31a6dbe4c1aeedd78c09eacb Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:22:30 +0100 Subject: Removed weight slider for negative prompts --- extensions-builtin/Lora/ui_edit_user_metadata.py | 7 +------ extensions-builtin/Lora/ui_extra_networks_lora.py | 6 +----- 2 files changed, 2 insertions(+), 11 deletions(-) (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index f7859b21..3160aecf 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,14 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["negative text"] = negative_text - user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -130,7 +129,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), user_metadata.get('negative text', ''), - float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -167,7 +165,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") - self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -204,7 +201,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -219,7 +215,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, self.edit_notes, ] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 9a6624e3..e714fac4 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -46,13 +46,9 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): item["prompt"] += " + " + quote_js(" " + activation_text) negative_prompt = item["user_metadata"].get("negative text") - preferred_negative_weight = item["user_metadata"].get("negative weight") item["negative_prompt"] = quote_js("") if negative_prompt: - neg_prompt = negative_prompt - if (preferred_negative_weight > 0): - neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: -- cgit v1.2.1