aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/networks.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-10-14 12:18:55 +0300
committerGitHub <noreply@github.com>2023-10-14 12:18:55 +0300
commit4be7b620c2994a14c0ce879d3a6cea0db0f66a84 (patch)
tree3c1ca88bf58264fdc8b50dc6738cc5b7720cb2de /extensions-builtin/Lora/networks.py
parent19f5795c27f13c84a922bc69f559cc69e408b3e2 (diff)
parenta8cbe50c9fa324ed887089e4333452ecc4355c92 (diff)
Merge pull request #13568 from AUTOMATIC1111/lora_emb_bundle
Add lora-embedding bundle system
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--extensions-builtin/Lora/networks.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index ddab3c55..60d8dec4 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -16,6 +16,9 @@ import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
module_types = [
network_lora.ModuleTypeLora(),
@@ -151,9 +154,19 @@ def load_network(name, network_on_disk):
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
matched_networks = {}
+ bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
+ if key_network_without_network_parts == "bundle_emb":
+ emb_name, vec_name = network_part.split(".", 1)
+ emb_dict = bundle_embeddings.get(emb_name, {})
+ if vec_name.split('.')[0] == 'string_to_param':
+ _, k2 = vec_name.split('.', 1)
+ emb_dict['string_to_param'] = {k2: weight}
+ else:
+ emb_dict[vec_name] = weight
+ bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -197,6 +210,14 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module
+ embeddings = {}
+ for emb_name, data in bundle_embeddings.items():
+ embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+ embedding.loaded = None
+ embeddings[emb_name] = embedding
+
+ net.bundle_embeddings = embeddings
+
if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -212,11 +233,15 @@ def purge_networks_from_memory():
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ emb_db = sd_hijack.model_hijack.embedding_db
already_loaded = {}
for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
+ for emb_name, embedding in net.bundle_embeddings.items():
+ if embedding.loaded:
+ emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
loaded_networks.clear()
@@ -259,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net)
+ for emb_name, embedding in net.bundle_embeddings.items():
+ if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+ logger.warning(
+ f'Skip bundle embedding: "{emb_name}"'
+ ' as it was already loaded from embeddings folder'
+ )
+ continue
+
+ embedding.loaded = False
+ if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+ embedding.loaded = True
+ emb_db.register_embedding(embedding, shared.sd_model)
+ else:
+ emb_db.skipped_embeddings[name] = embedding
+
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
@@ -567,6 +607,7 @@ extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
+loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}