aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-10-09 22:52:09 +0800
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-10-09 22:52:09 +0800
commit2aa485b5afb13fd6aab79777e4dfc488591b2f1c (patch)
treef2b5e5eda241fbaad8169aa0cfd79be11acb5660
parent7d60076b8b275771a1aa98f017aff845ef68d964 (diff)
add lora bundle system
-rw-r--r--extensions-builtin/Lora/network.py1
-rw-r--r--extensions-builtin/Lora/networks.py48
2 files changed, 49 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index d8e8dfb7..6021fd8d 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -93,6 +93,7 @@ class Network: # LoraModule
self.unet_multiplier = 1.0
self.dyn_dim = None
self.modules = {}
+ self.bundle_embeddings = {}
self.mtime = None
self.mentioned_name = None
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 315682b3..652b8ebe 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -15,6 +15,7 @@ import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+from modules.textual_inversion.textual_inversion import Embedding
module_types = [
network_lora.ModuleTypeLora(),
@@ -149,9 +150,15 @@ def load_network(name, network_on_disk):
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
matched_networks = {}
+ bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
+ if key_network_without_network_parts == "bundle_emb":
+ emb_name, vec_name = network_part.split(".", 1)
+ emb_dict = bundle_embeddings.get(emb_name, {})
+ emb_dict[vec_name] = weight
+ bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -195,6 +202,8 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module
+ net.bundle_embeddings = bundle_embeddings
+
if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -210,11 +219,14 @@ def purge_networks_from_memory():
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ emb_db = sd_hijack.model_hijack.embedding_db
already_loaded = {}
for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
+ for emb_name in net.bundle_embeddings:
+ emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
loaded_networks.clear()
@@ -257,6 +269,41 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net)
+ for emb_name, data in net.bundle_embeddings.items():
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+ vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+ shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+ vectors = data['clip_g'].shape[0]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ else:
+ raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
+
+ embedding = Embedding(vec, emb_name)
+ embedding.vectors = vectors
+ embedding.shape = shape
+
+ if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+ emb_db.register_embedding(embedding, shared.sd_model)
+ else:
+ emb_db.skipped_embeddings[name] = embedding
+
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
@@ -565,6 +612,7 @@ extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
+loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}