Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r--  extensions-builtin/Lora/lora.py  129
1 file changed, 106 insertions(+), 23 deletions(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 6f246921..7b56136f 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,10 +1,9 @@
-import glob
import os
import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, sd_models, errors, scripts
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -93,6 +92,7 @@ class LoraOnDisk:
self.metadata = m
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
+ self.alias = self.metadata.get('ss_output_name', self.name)
class LoraModule:
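
The alias introduced here comes from the LoRA's own training metadata: ss_output_name is the name the author gave the network at training time, with the file stem as the fallback when the tag is absent. A minimal sketch of that fallback, using a plain dict in place of parsed safetensors metadata and a made-up filename:

import os

metadata = {"ss_output_name": "MyStyleV2"}    # hypothetical parsed metadata
name = os.path.splitext(os.path.basename("/loras/my_style_v2_e10.safetensors"))[0]

alias = metadata.get("ss_output_name", name)  # -> "MyStyleV2"
bare = {}.get("ss_output_name", name)         # no tag -> falls back to "my_style_v2_e10"
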
@@ -165,12 +165,14 @@ def load_lora(name, filename):
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.Conv2d:
+ elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
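
The two Conv2d branches above exist because LoCon-style LoRAs can target 3x3 convolutions, not only the 1x1 projections supported before this change. A conv LoRA weight is stored as (out, in, kh, kw), so weight.shape[2:] recovers the kernel size; a quick sketch with fabricated shapes:

import torch

w_1x1 = torch.zeros(8, 320, 1, 1)   # hypothetical rank-8 delta for a 1x1 conv
w_3x3 = torch.zeros(8, 320, 3, 3)   # the same for a 3x3 conv

assert w_1x1.shape[2:] == (1, 1)
assert w_3x3.shape[2:] == (3, 3)
# shape[1] and shape[0] are the in/out channel counts, matching the
# Conv2d(weight.shape[1], weight.shape[0], ...) constructor calls above.
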
@@ -182,7 +184,7 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -199,11 +201,11 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
- loras_on_disk = [available_loras.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+ if any(x is None for x in loras_on_disk):
list_available_loras()
- loras_on_disk = [available_loras.get(name, None) for name in names]
+ loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
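
Name resolution in load_loras now goes through available_lora_aliases rather than available_loras, so a prompt can reference a LoRA by file stem or by its ss_output_name. A hedged sketch (hypothetical names) of the invariant that list_available_loras, later in this diff, establishes:

entry = available_loras["my_style_v2_e10"]               # keyed by file stem only
assert available_lora_aliases["my_style_v2_e10"] is entry
assert available_lora_aliases[entry.alias] is entry      # e.g. "MyStyleV2"
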
@@ -232,6 +234,8 @@ def lora_calc_updown(lora, module, target):
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
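
The new 3x3 branch folds the two LoRA convolutions into a single kernel: with down shaped (rank, in_ch, 3, 3) and up shaped (out_ch, rank, 1, 1), running conv2d over the permuted down-weight contracts the rank axis, producing exactly the kernel of up applied after down. A self-contained check with made-up sizes (not from the diff):

import torch
import torch.nn.functional as F

in_ch, rank, out_ch = 4, 2, 8
down = torch.randn(rank, in_ch, 3, 3)   # lora_down: 3x3 conv
up = torch.randn(out_ch, rank, 1, 1)    # lora_up: 1x1 conv

# Treat down as a batch of in_ch "images" with rank channels; the 1x1 up
# kernel then contracts rank, leaving an (out_ch, in_ch, 3, 3) kernel.
merged = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

x = torch.randn(1, in_ch, 16, 16)
two_step = F.conv2d(F.conv2d(x, down, padding=1), up)
one_step = F.conv2d(x, merged, padding=1)
assert torch.allclose(two_step, one_step, atol=1e-5)
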
@@ -240,6 +244,19 @@ def lora_calc_updown(lora, module, target):
return updown
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ weights_backup = getattr(self, "lora_weights_backup", None)
+
+ if weights_backup is None:
+ return
+
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
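
Factoring the restore step into lora_restore_weights_from_backup lets both this merged-weight path and the functional lora_forward path further down undo a previous merge before applying a new one. A minimal round-trip sketch for a plain Linear layer (MultiheadAttention keeps a pair of backups instead):

import torch

layer = torch.nn.Linear(4, 4)
layer.lora_weights_backup = layer.weight.detach().clone()

with torch.no_grad():
    layer.weight += 1.0                        # stand-in for a merged Lora delta
    lora_restore_weights_from_backup(layer)    # copies the backup back in

assert torch.equal(layer.weight, layer.lora_weights_backup)
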
@@ -264,12 +281,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
self.lora_weights_backup = weights_backup
if current_names != wanted_names:
- if weights_backup is not None:
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
- else:
- self.weight.copy_(weights_backup)
+ lora_restore_weights_from_backup(self)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
@@ -297,15 +309,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
+
+
+def lora_forward(module, input, original_forward):
+ """
+ Old way of applying Lora by executing operations during layer's forward.
+ Stacking many loras this way results in big performance degradation.
+ """
+
+ if len(loaded_loras) == 0:
+ return original_forward(module, input)
+
+ input = devices.cond_cast_unet(input)
+
+ lora_restore_weights_from_backup(module)
+ lora_reset_cached_weight(module)
+
+ res = original_forward(module, input)
+
+ lora_layer_name = getattr(module, 'lora_layer_name', None)
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is None:
+ continue
+
+ module.up.to(device=devices.device)
+ module.down.to(device=devices.device)
+
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ return res
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
+ if shared.opts.lora_functional:
+ return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
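
With the new lora_functional option enabled, the layer runs with its original weights and each loaded Lora adds up(down(x)), scaled by its multiplier and alpha/rank, on every forward; with it disabled, the same delta is merged into the weight once. A sketch of that equivalence for a hypothetical rank-4 Lora on a Linear layer, which also shows why stacking is cheaper in the merged path:

import torch

x = torch.randn(1, 320)
base = torch.nn.Linear(320, 320, bias=False)
down = torch.nn.Linear(320, 4, bias=False)          # hypothetical lora_down, rank 4
up = torch.nn.Linear(4, 320, bias=False)            # hypothetical lora_up
multiplier, alpha = 0.8, 4.0
scale = multiplier * (alpha / up.weight.shape[1])   # shape[1] is the rank

with torch.no_grad():
    # functional path: an extra down/up pair on every call, per loaded Lora
    functional = base(x) + up(down(x)) * scale

    # merged path: fold the same delta into the weight once, then run normally
    merged_weight = base.weight + (up.weight @ down.weight) * scale
    merged = x @ merged_weight.t()

assert torch.allclose(functional, merged, atol=1e-5)
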
@@ -318,6 +363,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
def lora_Conv2d_forward(self, input):
+ if shared.opts.lora_functional:
+ return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
@@ -343,24 +391,59 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
def list_available_loras():
available_loras.clear()
+ available_lora_aliases.clear()
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
- candidates = \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
-
+ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
+ entry = LoraOnDisk(name, filename)

- available_loras[name] = LoraOnDisk(name, filename)
+ available_loras[name] = entry
+
+ available_lora_aliases[name] = entry
+ available_lora_aliases[entry.alias] = entry
+
+
+re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+ if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+ return # if the other extension is active, it will handle those fields, no need to do anything
+
+ added = []
+
+ for k in params:
+ if not k.startswith("AddNet Model "):
+ continue
+
+ num = k[13:]
+
+ if params.get("AddNet Module " + num) != "LoRA":
+ continue
+
+ name = params.get("AddNet Model " + num)
+ if name is None:
+ continue
+
+ m = re_lora_name.match(name)
+ if m:
+ name = m.group(1)
+
+ multiplier = params.get("AddNet Weight A " + num, "1.0")
+
+ added.append(f"<lora:{name}:{multiplier}>")
+ if added:
+ params["Prompt"] += "\n" + "".join(added)
available_loras = {}
+available_lora_aliases = {}
loaded_loras = []
list_available_loras()
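
To see the AddNet conversion end to end, here is the same logic rerun as a standalone script on a fabricated params dict (the real infotext_pasted additionally bails out when the AddNet extension itself is active, and is hooked up through script callbacks elsewhere):

import re

re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")

params = {
    "Prompt": "a portrait",
    "AddNet Module 1": "LoRA",
    "AddNet Model 1": "myLora(abcd1234)",   # trailing hash gets stripped
    "AddNet Weight A 1": "0.8",
}

added = []
for k in params:
    if not k.startswith("AddNet Model "):
        continue
    num = k[13:]                            # len("AddNet Model ") == 13
    if params.get("AddNet Module " + num) != "LoRA":
        continue
    name = params[k]
    m = re_lora_name.match(name)
    if m:
        name = m.group(1)
    multiplier = params.get("AddNet Weight A " + num, "1.0")
    added.append(f"<lora:{name}:{multiplier}>")

if added:
    params["Prompt"] += "\n" + "".join(added)

print(params["Prompt"])   # a portrait\n<lora:myLora:0.8>
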