path: root/extensions-builtin/Lora/lora.py
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r--  extensions-builtin/Lora/lora.py  61
1 file changed, 46 insertions(+), 15 deletions(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 34ff57dd..9cdff6ed 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -3,7 +3,7 @@ import re
 import torch
 from typing import Union
 
-from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache
 
 metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
 
         return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
 
+    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
+        if 'mlp_fc1' in m[1]:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+        elif 'mlp_fc2' in m[1]:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+        else:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
     return key
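
[note] For illustration, a minimal standalone sketch of what the new lora_te2 branch does. The match() helper is a simplified stand-in for the one defined earlier in this file, and the example key is made up:

    import re

    def match(m, regex, key):
        # simplified stand-in: fill the list m with the regex groups on success
        r = re.match(regex, key)
        if not r:
            return False
        m.clear()
        m.extend(r.groups())
        return True

    def convert_te2_key(key):
        m = []
        if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)", key):
            # te2 is SDXL's second (OpenCLIP) text encoder, which uses resblock
            # naming, hence the mlp_fc1 -> mlp_c_fc style renames
            if 'mlp_fc1' in m[1]:
                return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
        return key

    print(convert_te2_key("lora_te2_text_model_encoder_layers_0_mlp_fc1"))
    # -> 1_model_transformer_resblocks_0_mlp_c_fc

The 1_ prefix matches the embedder-index prefix that assign_lora_names_to_compvis_modules() assigns further down in this diff.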
@@ -78,9 +86,15 @@ class LoraOnDisk:
         self.metadata = {}
         self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
 
+        def read_metadata():
+            metadata = sd_models.read_metadata_from_safetensors(filename)
+            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text
+
+            return metadata
+
         if self.is_safetensors:
             try:
-                self.metadata = sd_models.read_metadata_from_safetensors(filename)
+                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
             except Exception as e:
                 errors.display(e, f"reading lora {filename}")
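
[note] The exact semantics of cache.cached_data_for_file() live in modules/cache.py and are not part of this diff; the rough sketch below shows the behavior it plausibly provides, an mtime-keyed memoization of read_metadata(). All names and details here are assumptions for illustration, not the real implementation:

    import os

    _cache = {}  # hypothetical in-memory stand-in for webui's on-disk cache

    def cached_data_for_file(subsection, title, filename, func):
        # recompute via func() only when the file's mtime changes (assumed behavior)
        mtime = os.path.getmtime(filename)
        entry = _cache.get((subsection, title))
        if entry is None or entry["mtime"] != mtime:
            entry = {"mtime": mtime, "value": func()}
            _cache[(subsection, title)] = entry
        return entry["value"]

Under that reading, repeated listings only pay the safetensors metadata parse once per lora file until the file changes on disk.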
@@ -91,7 +105,6 @@ class LoraOnDisk:
 
             self.metadata = m
 
-        self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # those are cover images and they are too big to display in UI as text
         self.alias = self.metadata.get('ss_output_name', self.name)
 
         self.hash = None
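
[note] For context, the self.metadata = m line kept above is presumably the tail of a loop that reorders metadata so keys listed in metadata_tags_order come first; roughly:

    metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

    metadata = {"ss_tag_frequency": {}, "unknown_tag": "x", "ss_sd_model_name": "sd_xl_base"}
    m = {}
    for k, v in sorted(metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
        m[k] = v

    print(list(m))  # ['ss_sd_model_name', 'ss_tag_frequency', 'unknown_tag']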
@@ -142,10 +155,20 @@ class LoraUpDownModule:
 
 def assign_lora_names_to_compvis_modules(sd_model):
     lora_layer_mapping = {}
 
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        lora_name = name.replace(".", "_")
-        lora_layer_mapping[lora_name] = module
-        module.lora_layer_name = lora_name
+    if shared.sd_model.is_sdxl:
+        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
+            if not hasattr(embedder, 'wrapped'):
+                continue
+
+            for name, module in embedder.wrapped.named_modules():
+                lora_name = f'{i}_{name.replace(".", "_")}'
+                lora_layer_mapping[lora_name] = module
+                module.lora_layer_name = lora_name
+    else:
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            lora_name = name.replace(".", "_")
+            lora_layer_mapping[lora_name] = module
+            module.lora_layer_name = lora_name
 
     for name, module in shared.sd_model.model.named_modules():
         lora_name = name.replace(".", "_")
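
[note] A toy demonstration of the naming scheme the SDXL branch produces: each text-encoder embedder's modules get an index prefix, which is what lets the 1_model_transformer_resblocks_... keys emitted above find the second encoder. The embedder structure below is a made-up stand-in, not the real SGM conditioner:

    import torch.nn as nn

    embedders = [
        nn.ModuleDict({"transformer": nn.Linear(4, 4)}),  # stand-in for the CLIP ViT-L embedder
        nn.ModuleDict({"model": nn.Linear(4, 4)}),        # stand-in for the OpenCLIP ViT-bigG embedder
    ]

    lora_layer_mapping = {}
    for i, embedder in enumerate(embedders):
        for name, module in embedder.named_modules():
            if not name:
                continue  # the root module's name is the empty string
            lora_layer_mapping[f'{i}_{name.replace(".", "_")}'] = module

    print(sorted(lora_layer_mapping))  # ['0_transformer', '1_model']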
@@ -168,10 +191,10 @@ def load_lora(name, lora_on_disk):
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
 
-    for key_diffusers, weight in sd.items():
-        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
+    for key_lora, weight in sd.items():
+        key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
+        key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
 
         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
 
         if sd_module is None:
@@ -179,8 +202,16 @@ def load_lora(name, lora_on_disk):
             if m:
                 sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
 
+        # SDXL loras seem to already have correct compvis keys, so we only need to replace "lora_unet" with "diffusion_model"
+        if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
+            key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
+            key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+
         if sd_module is None:
-            keys_failed_to_match[key_diffusers] = key
+            keys_failed_to_match[key_lora] = key
             continue
 
         lora_module = lora.modules.get(key, None)
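
[note] As a worked example of the new fallback path (the key below is illustrative): an SDXL UNet key that convert_diffusers_name_to_compvis() passes through unchanged has its prefix rewritten to match the layer-mapping names:

    key_lora = "lora_unet_output_blocks_0_1_proj_in.lora_down.weight"

    key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
    # key_lora_without_lora_parts == "lora_unet_output_blocks_0_1_proj_in"
    # lora_key == "lora_down.weight"

    key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
    print(key)  # diffusion_model_output_blocks_0_1_proj_in

lora_te1 keys get the 0_transformer_text_model prefix instead, matching the 0_ index assigned to the first embedder above.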
@@ -203,9 +234,9 @@ def load_lora(name, lora_on_disk):
         elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
-            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
+            print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
-            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
+            raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")
 
         with torch.no_grad():
             module.weight.copy_(weight)
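
[note] A condensed sketch of the construction step above for the common Linear case, with a made-up weight shape:

    import torch

    weight = torch.randn(8, 320)  # hypothetical lora_down weight: (rank, in_features)

    module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
    with torch.no_grad():
        module.weight.copy_(weight)  # nn.Linear stores weight as (out_features, in_features)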
@@ -217,7 +248,7 @@ def load_lora(name, lora_on_disk):
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
+            raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")
 
     if keys_failed_to_match:
         print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
@@ -443,7 +474,7 @@ def list_available_loras():
 
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
     candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
-    for filename in sorted(candidates, key=str.lower):
+    for filename in candidates:
         if os.path.isdir(filename):
             continue
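
[note] Dropping sorted() here only makes sense if shared.walk_files() itself yields a stable order; a minimal sketch of such a helper under that assumption (this is not the actual modules/shared.py code):

    import os

    def walk_files(path, allowed_extensions=None):
        # yield files under path, sorted case-insensitively within each directory
        for root, dirs, files in os.walk(path):
            dirs.sort(key=str.lower)  # sorting dirs in-place fixes traversal order
            for filename in sorted(files, key=str.lower):
                if allowed_extensions is not None and os.path.splitext(filename)[1].lower() not in allowed_extensions:
                    continue
                yield os.path.join(root, filename)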