Diffstat (limited to 'extensions-builtin/Lora/network_lora.py')
-rw-r--r--  extensions-builtin/Lora/network_lora.py  72
1 file changed, 44 insertions, 28 deletions
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index b2d96537..26c0a72c 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -1,5 +1,6 @@
 import torch
 
+import lyco_helpers
 import network
 from modules import devices
 
@@ -16,29 +17,42 @@ class NetworkModuleLora(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):
         super().__init__(net, weights)
 
-        self.up = self.create_module(weights.w["lora_up.weight"])
-        self.down = self.create_module(weights.w["lora_down.weight"])
-        self.alpha = weights.w["alpha"] if "alpha" in weights.w else None
+        self.up_model = self.create_module(weights.w, "lora_up.weight")
+        self.down_model = self.create_module(weights.w, "lora_down.weight")
+        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
+
+        self.dim = weights.w["lora_down.weight"].shape[0]
+
+    def create_module(self, weights, key, none_ok=False):
+        weight = weights.get(key)
 
-    def create_module(self, weight, none_ok=False):
         if weight is None and none_ok:
             return None
 
-        if type(self.sd_module) == torch.nn.Linear:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(self.sd_module) == torch.nn.MultiheadAttention:
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+
+        if is_linear:
+            weight = weight.reshape(weight.shape[0], -1)
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
+        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
+            if len(weight.shape) == 2:
+                weight = weight.reshape(weight.shape[0], -1, 1, 1)
+
+            if weight.shape[2] != 1 or weight.shape[3] != 1:
+                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+            else:
+                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        elif is_conv and key == "lora_mid.weight":
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
-        elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
-            print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
-            return None
+            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
 
         with torch.no_grad():
+            if weight.shape != module.weight.shape:
+                weight = weight.reshape(module.weight.shape)
             module.weight.copy_(weight)
 
         module.to(device=devices.cpu, dtype=devices.dtype)
@@ -46,25 +60,27 @@ class NetworkModuleLora(network.NetworkModule):
         return module
 
-    def calc_updown(self, target):
-        up = self.up.weight.to(target.device, dtype=target.dtype)
-        down = self.down.weight.to(target.device, dtype=target.dtype)
+    def calc_updown(self, orig_weight):
+        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
 
-        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
-            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
-        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
-            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
+        output_shape = [up.size(0), down.size(1)]
+        if self.mid_model is not None:
+            # cp-decomposition
+            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
+            output_shape += mid.shape[2:]
         else:
-            updown = up @ down
-
-        updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
+            if len(down.shape) == 4:
+                output_shape += down.shape[2:]
+            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
 
-        return updown
+        return self.finalize_updown(updown, orig_weight, output_shape)
 
     def forward(self, x, y):
-        self.up.to(device=devices.device)
-        self.down.to(device=devices.device)
+        self.up_model.to(device=devices.device)
+        self.down_model.to(device=devices.device)
 
-        return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
+        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
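Note: the reworked calc_updown no longer special-cases 1x1 and 3x3 convolutions itself; it hands weight reconstruction to lyco_helpers.rebuild_conventional (plain up @ down, with optional dyn_dim rank truncation) and lyco_helpers.rebuild_cp_decomposition (the lora_mid.weight path). The following is a minimal sketch of what those two helpers are assumed to compute, for reference only, not a verbatim copy of the extension's helpers:

import torch

def rebuild_conventional(up, down, shape, dyn_dim=None):
    # Plain LoRA: flatten both factors to 2D, optionally truncate the rank
    # to dyn_dim, multiply, and reshape to the target weight's shape.
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)

def rebuild_cp_decomposition(up, down, mid):
    # CP-style decomposition: up and down act as 1x1 projections around a
    # small k x k "mid" kernel; the einsum contracts the rank dimensions to
    # recover a full (out, in, k, k) convolution delta.
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)

if __name__ == "__main__":
    # Toy shapes: a rank-4 LoRA for a Linear layer mapping 16 -> 32 features.
    up, down = torch.randn(32, 4), torch.randn(4, 16)
    print(rebuild_conventional(up, down, [32, 16]).shape)  # torch.Size([32, 16])

    # Toy shapes: a rank-4 CP decomposition of a 3x3 conv with 8 -> 8 channels.
    up, down, mid = torch.randn(8, 4, 1, 1), torch.randn(4, 8, 1, 1), torch.randn(4, 4, 3, 3)
    print(rebuild_cp_decomposition(up, down, mid).shape)  # torch.Size([8, 8, 3, 3])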