aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/lyco_helpers.py47
-rw-r--r--extensions-builtin/Lora/network_oft.py98
-rw-r--r--extensions-builtin/Lora/networks.py7
3 files changed, 135 insertions, 17 deletions
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 279b34bc..1679a0ce 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+ '''
+ return a tuple of two value of input dimension decomposed by the number closest to factor
+ second value is higher or equal than first value.
+
+ In LoRA with Kroneckor Product, first value is a value for weight scale.
+ secon value is a value for weight.
+
+ Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
+
+ examples)
+ factor
+ -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
+ 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
+ 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
+ 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
+ 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
+ 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
+ '''
+
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+ while m<n:
+ new_m = m + 1
+ while dimension%new_m != 0:
+ new_m += 1
+ new_n = dimension // new_m
+ if new_m + new_n > length or new_m>factor:
+ break
+ else:
+ m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
+
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index e43c9a1d..2be67fe5 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -1,34 +1,62 @@
import torch
import network
+from lyco_helpers import factorization
+from einops import rearrange
class ModuleTypeOFT(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
- if all(x in weights.w for x in ["oft_blocks"]):
+ if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
return NetworkModuleOFT(net, weights)
return None
-# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
- self.oft_blocks = weights.w["oft_blocks"]
- self.alpha = weights.w["alpha"]
- self.dim = self.oft_blocks.shape[0]
- self.num_blocks = self.dim
+ self.lin_module = None
+ self.org_module: list[torch.Module] = [self.sd_module]
+
+ # kohya-ss
+ if "oft_blocks" in weights.w.keys():
+ self.is_kohya = True
+ self.oft_blocks = weights.w["oft_blocks"]
+ self.alpha = weights.w["alpha"]
+ self.dim = self.oft_blocks.shape[0]
+ elif "oft_diag" in weights.w.keys():
+ self.is_kohya = False
+ self.oft_blocks = weights.w["oft_diag"]
+ # alpha is rank if alpha is 0 or None
+ if self.alpha is None:
+ pass
+ self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
+ else:
+ raise ValueError("oft_blocks or oft_diag must be in weights dict")
- if "Linear" in self.sd_module.__class__.__name__:
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+ is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
+
+ if is_linear:
self.out_dim = self.sd_module.out_features
- elif "Conv" in self.sd_module.__class__.__name__:
+ elif is_other_linear:
+ self.out_dim = self.sd_module.embed_dim
+ elif is_conv:
self.out_dim = self.sd_module.out_channels
+ else:
+ raise ValueError("sd_module must be Linear or Conv")
- self.constraint = self.alpha * self.out_dim
- self.block_size = self.out_dim // self.num_blocks
-
- self.org_module: list[torch.Module] = [self.sd_module]
+ if self.is_kohya:
+ self.num_blocks = self.dim
+ self.block_size = self.out_dim // self.num_blocks
+ self.constraint = self.alpha * self.out_dim
+ else:
+ self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+ self.constraint = None
def merge_weight(self, R_weight, org_weight):
R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
@@ -39,31 +67,67 @@ class NetworkModuleOFT(network.NetworkModule):
return weight
def get_weight(self, oft_blocks, multiplier=None):
- constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
+ if self.constraint is not None:
+ constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
- new_norm_Q = torch.clamp(norm_Q, max=constraint)
+ if self.constraint is not None:
+ new_norm_Q = torch.clamp(norm_Q, max=constraint)
+ else:
+ new_norm_Q = norm_Q
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I
R = torch.block_diag(*block_R_weighted)
-
return R
- def calc_updown(self, orig_weight):
- multiplier = self.multiplier() * self.calc_scale()
+ def calc_updown_kohya(self, orig_weight, multiplier):
R = self.get_weight(self.oft_blocks, multiplier)
merged_weight = self.merge_weight(R, orig_weight)
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
orig_weight = orig_weight
+ return self.finalize_updown(updown, orig_weight, output_shape)
+
+ def calc_updown_kb(self, orig_weight, multiplier):
+ is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
+
+ if not is_other_linear:
+ if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
+ orig_weight=orig_weight.permute(1, 0)
+
+ R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+ merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+ merged_weight = torch.einsum(
+ 'k n m, k n ... -> k m ...',
+ R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
+ merged_weight
+ )
+ merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+
+ if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
+ orig_weight=orig_weight.permute(1, 0)
+
+ updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+ output_shape = orig_weight.shape
+ else:
+ # FIXME: skip MultiheadAttention for now
+ updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
+ output_shape = (orig_weight.shape[1], orig_weight.shape[1])
return self.finalize_updown(updown, orig_weight, output_shape)
+ def calc_updown(self, orig_weight):
+ multiplier = self.multiplier() * self.calc_scale()
+ if self.is_kohya:
+ return self.calc_updown_kohya(orig_weight, multiplier)
+ else:
+ return self.calc_updown_kb(orig_weight, multiplier)
+
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
#return super().finalize_updown(updown, orig_weight, output_shape, ex_bias)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 78a97033..7f814706 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -191,10 +191,17 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # kohya_ss OFT module
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # KohakuBlueLeaf OFT module
+ if sd_module is None and "oft_diag" in key:
+ key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
if sd_module is None:
keys_failed_to_match[key_network] = key
continue