aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-13 08:16:37 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-13 08:16:37 +0300
commitb08500cec8a791ef20082628b49b17df833f5dda (patch)
tree2d09f3ca93139f082b88463f3a2a43a4ac45526f /extensions-builtin/Lora
parent5ab7f213bec2f816f9c5644becb32eb72c8ffb89 (diff)
parent231562ea13e4f697953bdbabd6b76b22a88c587b (diff)
Merge branch 'release_candidate'
Diffstat (limited to 'extensions-builtin/Lora')
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py1
-rw-r--r--extensions-builtin/Lora/lora.py116
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py27
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py2
4 files changed, 128 insertions, 18 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 45f899fc..ccb249ac 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,6 +1,7 @@
from modules import extra_networks, shared
import lora
+
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 6f246921..ba1293df 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -4,7 +4,7 @@ import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, sd_models, errors, scripts
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -93,6 +93,7 @@ class LoraOnDisk:
self.metadata = m
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
+ self.alias = self.metadata.get('ss_output_name', self.name)
class LoraModule:
@@ -165,8 +166,10 @@ def load_lora(name, filename):
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.Conv2d:
+ elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
@@ -199,11 +202,11 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
- loras_on_disk = [available_loras.get(name, None) for name in names]
+ loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]):
list_available_loras()
- loras_on_disk = [available_loras.get(name, None) for name in names]
+ loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
@@ -232,6 +235,8 @@ def lora_calc_updown(lora, module, target):
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
@@ -240,6 +245,19 @@ def lora_calc_updown(lora, module, target):
return updown
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ weights_backup = getattr(self, "lora_weights_backup", None)
+
+ if weights_backup is None:
+ return
+
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
@@ -264,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
self.lora_weights_backup = weights_backup
if current_names != wanted_names:
- if weights_backup is not None:
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
- else:
- self.weight.copy_(weights_backup)
+ lora_restore_weights_from_backup(self)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
@@ -300,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
setattr(self, "lora_current_names", wanted_names)
+def lora_forward(module, input, original_forward):
+ """
+ Old way of applying Lora by executing operations during layer's forward.
+ Stacking many loras this way results in big performance degradation.
+ """
+
+ if len(loaded_loras) == 0:
+ return original_forward(module, input)
+
+ input = devices.cond_cast_unet(input)
+
+ lora_restore_weights_from_backup(module)
+ lora_reset_cached_weight(module)
+
+ res = original_forward(module, input)
+
+ lora_layer_name = getattr(module, 'lora_layer_name', None)
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is None:
+ continue
+
+ module.up.to(device=devices.device)
+ module.down.to(device=devices.device)
+
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ return res
+
+
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input):
+ if shared.opts.lora_functional:
+ return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
@@ -318,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
def lora_Conv2d_forward(self, input):
+ if shared.opts.lora_functional:
+ return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
@@ -343,24 +392,59 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
def list_available_loras():
available_loras.clear()
+ available_lora_aliases.clear()
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
- candidates = \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
- glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
-
+ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
+ entry = LoraOnDisk(name, filename)
+
+ available_loras[name] = entry
+
+ available_lora_aliases[name] = entry
+ available_lora_aliases[entry.alias] = entry
+
+
+re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+ if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+ return # if the other extension is active, it will handle those fields, no need to do anything
+
+ added = []
+
+ for k, v in params.items():
+ if not k.startswith("AddNet Model "):
+ continue
+
+ num = k[13:]
+
+ if params.get("AddNet Module " + num) != "LoRA":
+ continue
+
+ name = params.get("AddNet Model " + num)
+ if name is None:
+ continue
+
+ m = re_lora_name.match(name)
+ if m:
+ name = m.group(1)
+
+ multiplier = params.get("AddNet Weight A " + num, "1.0")
- available_loras[name] = LoraOnDisk(name, filename)
+ added.append(f"<lora:{name}:{multiplier}>")
+ if added:
+ params["Prompt"] += "\n" + "".join(added)
available_loras = {}
+available_lora_aliases = {}
loaded_loras = []
list_available_loras()
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 3fc38ab9..7db971fd 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,12 +1,12 @@
import torch
import gradio as gr
+from fastapi import FastAPI
import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
-
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
@@ -49,8 +49,33 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
+script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
}))
+
+
+shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
+ "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+}))
+
+
+def create_lora_json(obj: lora.LoraOnDisk):
+ return {
+ "name": obj.name,
+ "alias": obj.alias,
+ "path": obj.filename,
+ "metadata": obj.metadata,
+ }
+
+
+def api_loras(_: gr.Blocks, app: FastAPI):
+ @app.get("/sdapi/v1/loras")
+ async def get_loras():
+ return [create_lora_json(obj) for obj in lora.available_loras.values()]
+
+
+script_callbacks.on_app_started(api_loras)
+
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 68b11332..a0edbc1e 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
- "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+ "prompt": json.dumps(f"<lora:{lora_on_disk.alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}