Diffstat (limited to 'extensions-builtin/Lora')
 extensions-builtin/Lora/networks.py | 31 ++-----------------------------
 1 file changed, 2 insertions(+), 29 deletions(-)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 12f70576..d5f0f9f1 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -15,7 +15,7 @@
 import torch
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
-from modules.textual_inversion.textual_inversion import Embedding
+import modules.textual_inversion.textual_inversion as textual_inversion
 
 from lora_logger import logger
@@ -210,34 +210,7 @@ def load_network(name, network_on_disk):
 
     embeddings = {}
     for emb_name, data in bundle_embeddings.items():
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, emb_name)
-        embedding.vectors = vectors
-        embedding.shape = shape
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
         embedding.loaded = None
         embeddings[emb_name] = embedding
 
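For reference, the removed per-format branches now live behind a single helper in the textual inversion module. Below is a minimal sketch of what textual_inversion.create_embedding_from_data plausibly looks like, reconstructed from the branches deleted above; the exact signature, the default filename value, and the attributes set on the returned object are assumptions, not the verified implementation in modules/textual_inversion/textual_inversion.py.

# Hedged sketch of the new helper, reconstructed from the code removed in the
# diff above. Inside modules/textual_inversion/textual_inversion.py the Embedding
# class is already in scope; the `filename` parameter and its default are
# assumptions based on the call site in the diff.
import torch

from modules import devices


def create_embedding_from_data(data, name, filename='unknown embedding file'):
    if 'string_to_param' in data:  # textual inversion embedding
        param_dict = data['string_to_param']
        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(param_dict.items()))[1]
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
        vectors = data['clip_g'].shape[0]
    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    else:
        raise Exception(f"Couldn't identify {name} in {filename} as neither textual inversion embedding nor diffuser concept.")

    embedding = Embedding(vec, name)
    embedding.vectors = vectors
    embedding.shape = shape
    embedding.filename = filename  # assumption: retained so errors can point at the source file
    return embedding

At the call site above, filename is built as network_on_disk.filename + "/" + emb_name, so a bundled embedding that fails to parse can be traced back to the specific Lora file it shipped in.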