author     AUTOMATIC1111 <16777216c@gmail.com>  2023-12-16 09:58:07 +0300
committer  AUTOMATIC1111 <16777216c@gmail.com>  2023-12-16 09:58:07 +0300
commit     cf2772fab0af5573da775e7437e6acdca424f26e (patch)
tree       2ad13a0cf77bc189a8c9097bd507f9674f993da6 /extensions-builtin/Lora/networks.py
parent     4afaaf8a020c1df457bcf7250cb1c7f609699fa7 (diff)
parent     0dfffe53ec11b2ee097d55efc479f8e707015db9 (diff)
Merge branch 'release_candidate'
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--  extensions-builtin/Lora/networks.py  60
1 file changed, 59 insertions(+), 1 deletion(-)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 96f935b2..629bf853 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -5,16 +5,21 @@ import re
import lora_patches
import network
import network_lora
+import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
+import network_oft
import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
module_types = [
network_lora.ModuleTypeLora(),
@@ -23,6 +28,8 @@ module_types = [
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
+ network_oft.ModuleTypeOFT(),
]
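
The two new entries hook GLoRA and OFT weights into the same per-layer dispatch the existing module types use. A minimal sketch of that dispatch, assuming the create_module(net, weights) hook that the module types in this file implement (each returns None for weight layouts it does not recognize):

    # probe every registered module type; the first one that recognizes the
    # key layout of this layer's weights produces the network module
    net_module = None
    for module_type in module_types:
        net_module = module_type.create_module(net, weights)
        if net_module is not None:
            break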
@@ -149,9 +156,20 @@ def load_network(name, network_on_disk):
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}
+    bundle_embeddings = {}

    for key_network, weight in sd.items():
-        key_network_without_network_parts, network_part = key_network.split(".", 1)
+        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
+        if key_network_without_network_parts == "bundle_emb":
+            emb_name, vec_name = network_part.split(".", 1)
+            emb_dict = bundle_embeddings.get(emb_name, {})
+            if vec_name.split('.')[0] == 'string_to_param':
+                _, k2 = vec_name.split('.', 1)
+                emb_dict['string_to_param'] = {k2: weight}
+            else:
+                emb_dict[vec_name] = weight
+            bundle_embeddings[emb_name] = emb_dict

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
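
Two things about the partition change: str.partition cannot fail on a key without a dot (the old two-way split would raise ValueError on unpacking), and it cleanly exposes the first path component as a discriminator for the new "bundle_emb" branch. A standalone sketch of how a bundled-embedding key decomposes (the embedding name is made up for illustration):

    key_network = "bundle_emb.my_style.string_to_param.*"
    prefix, _, network_part = key_network.partition(".")  # "bundle_emb", "my_style.string_to_param.*"
    emb_name, vec_name = network_part.split(".", 1)       # "my_style", "string_to_param.*"
    _, k2 = vec_name.split(".", 1)                        # k2 == "*"
    # the tensor lands in bundle_embeddings["my_style"]["string_to_param"]["*"]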
@@ -174,6 +192,17 @@ def load_network(name, network_on_disk):
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+        # kohya_ss OFT module
+        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
+            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+        # KohakuBlueLeaf OFT module
+        if sd_module is None and "oft_diag" in key:
+            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
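
Both OFT flavors reuse the established remapping pattern: rewrite the prefix the trainer emitted into the prefix that network_layer_mapping is indexed by, then retry the lookup. One reviewer-visible quirk: in the oft_diag branch both assignments start over from key_network_without_network_parts, so the second replace overwrites the first rather than chaining, and only the text-model rewrite determines the final lookup key. A made-up example of the kohya_ss remap:

    # hypothetical kohya_ss OFT key for a unet projection layer
    key_network_without_network_parts = "oft_unet_down_blocks_1_attentions_0_proj_in"
    key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
    # -> "diffusion_model_down_blocks_1_attentions_0_proj_in", then looked up
    #    in shared.sd_model.network_layer_mapping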
@@ -195,6 +224,14 @@ def load_network(name, network_on_disk):
        net.modules[key] = net_module

+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
+
    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
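
The loaded attribute is used as a tri-state flag from here on; the semantics below are inferred from how the rest of this diff reads and writes it:

    embedding = textual_inversion.create_embedding_from_data(
        data, emb_name,
        filename=network_on_disk.filename + "/" + emb_name,  # virtual path, for reporting only
    )
    # tri-state flag (semantics inferred from this diff):
    #   None  -> built from the bundle, never offered to the embedding database
    #   False -> offered, but skipped (e.g. shape mismatch)
    #   True  -> registered and active for prompt parsing
    embedding.loaded = None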
@@ -210,11 +247,15 @@ def purge_networks_from_memory():
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()
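
Before the loaded list is rebuilt, bundle embeddings belonging to networks that stay loaded are dropped from the embedding database so the pass below can re-register them cleanly. This reading assumes register_embedding_by_name treats a None embedding as removal, which is how the call is used here:

    # sketch: unhook a previously registered bundle embedding by name
    # (reusing the made-up name from the earlier sketch)
    emb_db.register_embedding_by_name(None, shared.sd_model, "my_style")
    # "my_style" then stops resolving in prompts until it is re-registered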
@@ -257,6 +298,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
+
+            embedding.loaded = False
+            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
+                emb_db.register_embedding(embedding, shared.sd_model)
+            else:
+                emb_db.skipped_embeddings[name] = embedding
+
    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
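
The shape gate mirrors the embeddings-folder loader: expected_shape appears to be the vector width the current text encoder requires, with -1 meaning not yet determined (for reference, SD1-class models use 768 and SD2-class 1024). Also worth a reviewer's eye: the else branch files the skip under name, the network name from the enclosing loop, rather than emb_name. Condensed:

    if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
        embedding.loaded = True
        emb_db.register_embedding(embedding, shared.sd_model)  # usable with this model
    else:
        embedding.loaded = False  # wrong base model; left unregistered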
@@ -418,6 +474,7 @@ def network_forward(module, input, original_forward):
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None
+    self.network_bias_backup = None
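
The backups are lazy snapshots of a layer's original parameters, taken the first time networks modify it and restored when the active networks change; clearing network_bias_backup together with the weight backup makes the next pass re-capture the bias instead of restoring a stale one. Roughly (a sketch, not the exact snapshot code elsewhere in this file):

    # on the next patched forward, both snapshots are retaken from the
    # unmodified layer before any network deltas are applied
    if getattr(self, "network_weights_backup", None) is None:
        self.network_weights_backup = self.weight.to(devices.cpu, copy=True)
    if getattr(self, "network_bias_backup", None) is None and self.bias is not None:
        self.network_bias_backup = self.bias.to(devices.cpu, copy=True)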
def network_Linear_forward(self, input):
@@ -564,6 +621,7 @@ extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
+loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}