From 735c9e8059384d4f640e5582413c30871f83eac5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:38:32 +0800 Subject: Fix network_oft --- extensions-builtin/Lora/network_oft.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c37811..44465f7a 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -53,12 +53,17 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): + I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -70,15 +75,10 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -94,4 +94,5 @@ class NetworkModuleOFT(network.NetworkModule): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown, ex_bias + # Ignore calc_scale, which is not used in OFT. + return updown * self.multiplier(), ex_bias -- cgit v1.2.1 From 265bc26c21264d63956e8f30f1ce31dec917fc76 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:43:24 +0800 Subject: Use self.scale instead of custom finalize --- extensions-builtin/Lora/network_oft.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 44465f7a..e3ae61a2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -78,21 +80,3 @@ class NetworkModuleOFT(network.NetworkModule): print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - # Ignore calc_scale, which is not used in OFT. - return updown * self.multiplier(), ex_bias -- cgit v1.2.1 From 8fc67f3851babd4575d3312b931d5e7c2b0c78c6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:44:49 +0800 Subject: remove debug print --- extensions-builtin/Lora/network_oft.py | 1 - 1 file changed, 1 deletion(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e3ae61a2..ff4eb59b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -77,6 +77,5 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.1 From 3772a82a70769fe1aac884a75bf5a3313fb83328 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:47:13 +0800 Subject: better naming and correct order for device. --- extensions-builtin/Lora/network_oft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff4eb59b..fa647020 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,14 +56,15 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) -- cgit v1.2.1 From f8f38c7c28e48f9f79225c969e3e82b1adcfb910 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:31:48 +0800 Subject: Fix dtype casting for OFT module --- extensions-builtin/Lora/network_oft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fa647020..342fcd0d 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=self.oft_blocks.device) if self.is_kohya: @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = oft_blocks.to(orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.1 From fd383140cf405100f3c619f106472273a7545beb Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 22 Jan 2024 02:52:34 -0800 Subject: fix: wrong devices for eye and constraint --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 342fcd0d..d1c46a4b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,12 +57,12 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): oft_blocks = self.oft_blocks.to(orig_weight.device) - eye = torch.eye(self.block_size, device=self.oft_blocks.device) + eye = torch.eye(self.block_size, device=oft_blocks.device) if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) -- cgit v1.2.1 From 92ab0ef7d65ededa758f81e52cf4f48f72d13564 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 19 Feb 2024 10:05:30 +0300 Subject: Merge pull request #14871 from v0xie/boft Support inference with LyCORIS BOFT networks --- extensions-builtin/Lora/network_oft.py | 58 ++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 10 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index d1c46a4b..d658ad10 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -22,6 +22,8 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.scale = 1.0 + self.is_kohya = False + self.is_boft = False # kohya-ss if "oft_blocks" in weights.w.keys(): @@ -29,13 +31,19 @@ class NetworkModuleOFT(network.NetworkModule): self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) self.alpha = weights.w["alpha"] # alpha is constraint self.dim = self.oft_blocks.shape[0] # lora dim - # LyCORIS + # LyCORIS OFT elif "oft_diag" in weights.w.keys(): - self.is_kohya = False self.oft_blocks = weights.w["oft_diag"] # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + # LyCORIS BOFT + if weights.w["oft_diag"].dim() == 4: + self.is_boft = True + self.rescale = weights.w.get('rescale', None) + if self.rescale is not None: + self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported @@ -51,6 +59,13 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = self.alpha * self.out_dim self.num_blocks = self.dim self.block_size = self.out_dim // self.dim + elif self.is_boft: + self.constraint = None + self.boft_m = weights.w["oft_diag"].shape[0] + self.block_num = weights.w["oft_diag"].shape[1] + self.block_size = weights.w["oft_diag"].shape[2] + self.boft_b = self.block_size + #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) @@ -68,14 +83,37 @@ class NetworkModuleOFT(network.NetworkModule): R = oft_blocks.to(orig_weight.device) - # This errors out for MultiheadAttention, might need to be handled up-stream - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if not self.is_boft: + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + else: + # TODO: determine correct value for scale + scale = 1.0 + m = self.boft_m + b = self.boft_b + r_b = b // 2 + inp = orig_weight + for i in range(m): + bi = R[i] # b_num, b_size, b_size + if i == 0: + # Apply multiplier/scale and rescale into first weight + bi = bi * scale + (1 - scale) * eye + inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) + inp = rearrange(inp, "(d b) ... -> d b ...", b=b) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = rearrange(inp, "d b ... -> (d b) ...") + inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) + merged_weight = inp + + # Rescale mechanism + if self.rescale is not None: + merged_weight = self.rescale.to(merged_weight) * merged_weight updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape -- cgit v1.2.1 From a10c8df8761c01801bac60d7977ae7e997ab51b0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 26 Feb 2024 07:12:12 +0300 Subject: Merge pull request #14973 from AUTOMATIC1111/Fix-new-oft-boft Fix the OFT/BOFT bugs when using new LyCORIS implementation --- extensions-builtin/Lora/network_oft.py | 50 ++++++++++++++++------------------ 1 file changed, 24 insertions(+), 26 deletions(-) (limited to 'extensions-builtin/Lora/network_oft.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index d658ad10..7821a8a7 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,5 @@ import torch import network -from lyco_helpers import factorization from einops import rearrange @@ -22,24 +21,24 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.scale = 1.0 - self.is_kohya = False + self.is_R = False self.is_boft = False - # kohya-ss + # kohya-ss/New LyCORIS OFT/BOFT if "oft_blocks" in weights.w.keys(): - self.is_kohya = True self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) - self.alpha = weights.w["alpha"] # alpha is constraint + self.alpha = weights.w.get("alpha", None) # alpha is constraint self.dim = self.oft_blocks.shape[0] # lora dim - # LyCORIS OFT + # Old LyCORIS OFT elif "oft_diag" in weights.w.keys(): + self.is_R = True self.oft_blocks = weights.w["oft_diag"] # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) - # LyCORIS BOFT - if weights.w["oft_diag"].dim() == 4: - self.is_boft = True + # LyCORIS BOFT + if self.oft_blocks.dim() == 4: + self.is_boft = True self.rescale = weights.w.get('rescale', None) if self.rescale is not None: self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) @@ -55,30 +54,29 @@ class NetworkModuleOFT(network.NetworkModule): elif is_other_linear: self.out_dim = self.sd_module.embed_dim - if self.is_kohya: - self.constraint = self.alpha * self.out_dim - self.num_blocks = self.dim - self.block_size = self.out_dim // self.dim - elif self.is_boft: + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim + self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim + if self.is_R: self.constraint = None - self.boft_m = weights.w["oft_diag"].shape[0] - self.block_num = weights.w["oft_diag"].shape[1] - self.block_size = weights.w["oft_diag"].shape[2] + self.block_size = self.dim + self.num_blocks = self.out_dim // self.dim + elif self.is_boft: + self.boft_m = self.oft_blocks.shape[0] + self.num_blocks = self.oft_blocks.shape[1] + self.block_size = self.oft_blocks.shape[2] self.boft_b = self.block_size - #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim) - else: - self.constraint = None - self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=oft_blocks.device) - if self.is_kohya: - block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + if not self.is_R: + block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix + if self.constraint != 0: + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device) -- cgit v1.2.1