From b75b004fe62826455f1aa77e849e7da13902cb17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 23:13:55 +0300 Subject: lora extension rework to include other types of networks --- extensions-builtin/Lora/network_lora.py | 70 +++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 extensions-builtin/Lora/network_lora.py (limited to 'extensions-builtin/Lora/network_lora.py') diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py new file mode 100644 index 00000000..b2d96537 --- /dev/null +++ b/extensions-builtin/Lora/network_lora.py @@ -0,0 +1,70 @@ +import torch + +import network +from modules import devices + + +class ModuleTypeLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): + return NetworkModuleLora(net, weights) + + return None + + +class NetworkModuleLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.up = self.create_module(weights.w["lora_up.weight"]) + self.down = self.create_module(weights.w["lora_down.weight"]) + self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + + def create_module(self, weight, none_ok=False): + if weight is None and none_ok: + return None + + if type(self.sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) + else: + print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + return None + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + + def calc_updown(self, target): + up = self.up.weight.to(target.device, dtype=target.dtype) + down = self.down.weight.to(target.device, dtype=target.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + + return updown + + def forward(self, x, y): + self.up.to(device=devices.device) + self.down.to(device=devices.device) + + return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + + -- cgit v1.2.1 From 238adeaffb037dedbcefe41e7fd4814a1f17baa2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:00:47 +0300 Subject: support specifying te and unet weights separately update lora code support full module --- extensions-builtin/Lora/network_lora.py | 72 ++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 28 deletions(-) (limited to 'extensions-builtin/Lora/network_lora.py') diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index b2d96537..26c0a72c 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -1,5 +1,6 @@ import torch +import lyco_helpers import network from modules import devices @@ -16,29 +17,42 @@ class NetworkModuleLora(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) - self.up = self.create_module(weights.w["lora_up.weight"]) - self.down = self.create_module(weights.w["lora_down.weight"]) - self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) - def create_module(self, weight, none_ok=False): if weight is None and none_ok: return None - if type(self.sd_module) == torch.nn.Linear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.MultiheadAttention: + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: - print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - return None + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) module.weight.copy_(weight) module.to(device=devices.cpu, dtype=devices.dtype) @@ -46,25 +60,27 @@ class NetworkModuleLora(network.NetworkModule): return module - def calc_updown(self, target): - up = self.up.weight.to(target.device, dtype=target.dtype) - down = self.down.weight.to(target.device, dtype=target.dtype) + def calc_updown(self, orig_weight): + up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + output_shape = [up.size(0), down.size(1)] + if self.mid_model is not None: + # cp-decomposition + mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] else: - updown = up @ down - - updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) - return updown + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y): - self.up.to(device=devices.device) - self.down.to(device=devices.device) + self.up_model.to(device=devices.device) + self.down_model.to(device=devices.device) - return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() -- cgit v1.2.1