Diffstat (limited to 'extensions-builtin')
-rw-r--r--  extensions-builtin/Lora/lora_logger.py   33
-rw-r--r--  extensions-builtin/Lora/networks.py       80
2 files changed, 81 insertions(+), 32 deletions(-)
diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py
new file mode 100644
index 00000000..d50e90f0
--- /dev/null
+++ b/extensions-builtin/Lora/lora_logger.py
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)
\ No newline at end of file
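The new module builds a dedicated, colorized stdout logger for the Lora extension. A minimal sketch of how it behaves once imported (the module name matches the `from lora_logger import logger` line added to networks.py below; the message text and embedding name are illustrative). With no level set explicitly, the "lora" logger falls back to the root level (WARNING by default), so only warnings and above reach stdout until the level is raised:

import logging

from lora_logger import logger

# Printed with the level name wrapped in ANSI color codes by ColoredFormatter,
# e.g. [lora]-WARNING: ...  ("WARNING" rendered in yellow on a color terminal)
logger.warning('Skip bundle embedding: "my_emb" as it was already loaded from embeddings folder')

# Hypothetical: raise verbosity if DEBUG output from the extension is wanted.
logger.setLevel(logging.DEBUG)
logger.debug("now visible on stdout")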
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 465e24c8..12f70576 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -17,6 +17,8 @@ from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding

+from lora_logger import logger
+
module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
@@ -206,7 +208,40 @@ def load_network(name, network_on_disk):
            net.modules[key] = net_module

-    net.bundle_embeddings = bundle_embeddings
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        # textual inversion embeddings
+        if 'string_to_param' in data:
+            param_dict = data['string_to_param']
+            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+            emb = next(iter(param_dict.items()))[1]
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        else:
+            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
+
+        embedding = Embedding(vec, emb_name)
+        embedding.vectors = vectors
+        embedding.shape = shape
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
-            for emb_name in net.bundle_embeddings:
-                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()
@@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

-        for emb_name, data in net.bundle_embeddings.items():
-            # textual inversion embeddings
-            if 'string_to_param' in data:
-                param_dict = data['string_to_param']
-                param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-                emb = next(iter(param_dict.items()))[1]
-                vec = emb.detach().to(devices.device, dtype=torch.float32)
-                shape = vec.shape[-1]
-                vectors = vec.shape[0]
-            elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-                vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-                shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-                vectors = data['clip_g'].shape[0]
-            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-                emb = next(iter(data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-                vec = emb.detach().to(devices.device, dtype=torch.float32)
-                shape = vec.shape[-1]
-                vectors = vec.shape[0]
-            else:
-                raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-            embedding = Embedding(vec, emb_name)
-            embedding.vectors = vectors
-            embedding.shape = shape
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
+            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                emb_db.skipped_embeddings[name] = embedding
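The refactor hinges on the tri-state loaded attribute set in load_network: None means registration was never attempted, False means it was attempted but rejected, and True means the embedding made it into the database. A rough standalone sketch of the decision above (the helper and its register/warn_skip parameters are hypothetical, not part of the patch):

def register_bundle_embeddings(net, emb_db, register, warn_skip):
    # Hypothetical helper mirroring the loop above.
    for emb_name, embedding in net.bundle_embeddings.items():
        # On the first attempt (loaded is None), an embedding already loaded from
        # the embeddings folder takes precedence over the copy bundled in the lora.
        if embedding.loaded is None and emb_name in emb_db.word_embeddings:
            warn_skip(emb_name)
            continue
        embedding.loaded = False
        if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
            embedding.loaded = True
            register(embedding)

Because the earlier hunk re-registers bundled embeddings by name only when embedding.loaded is truthy, embeddings that lost to an on-disk file or failed the shape check stay out of the database when the same network is loaded again.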