Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--    modules/textual_inversion/textual_inversion.py    29
-rw-r--r--    modules/textual_inversion/ui.py                    4
2 files changed, 21 insertions, 12 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c0baaace..1183aab7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,7 +7,7 @@ import tqdm
import html
import datetime
-from modules import shared, devices, sd_hijack, processing
+from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -17,6 +17,8 @@ class Embedding:
self.name = name
self.step = step
self.cached_checksum = None
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
def save(self, filename):
embedding_data = {
@@ -24,6 +26,8 @@ class Embedding:
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
+ "sd_checkpoint": self.sd_checkpoint,
+ "sd_checkpoint_name": self.sd_checkpoint_name,
}
torch.save(embedding_data, filename)
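Note: the two new keys travel with the rest of the payload that save() writes, so the checkpoint an embedding was trained against can be recovered from the file alone. A minimal sketch of inspecting such a file outside the webui, assuming only the keys shown in the dict above (the filename is illustrative):

    import torch

    # Load an embedding file written by Embedding.save(); keys match the dict above.
    data = torch.load("my-embedding.pt", map_location="cpu")
    print(data["name"], data["step"])
    print(data["sd_checkpoint"], data["sd_checkpoint_name"])  # new provenance fields
    print(data["string_to_param"]["*"].shape)                 # one row per vector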
@@ -41,6 +45,7 @@ class Embedding:
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
+
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
@@ -57,7 +62,8 @@ class EmbeddingDatabase:
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, embedding))
+
+ self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
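Note: replacing the plain append keeps every ids_lookup bucket ordered by token-sequence length, longest first, so an embedding whose name tokenizes to a longer sequence is matched before one that shares its first token. A self-contained sketch of the same insertion rule (token ids and names are made up):

    # Buckets keyed by first token id; kept sorted longest-first on every insert.
    ids_lookup = {}

    def register(ids, embedding):
        bucket = ids_lookup.setdefault(ids[0], [])
        bucket[:] = sorted(bucket + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

    register([42], "short-name")
    register([42, 7], "longer-name")
    print([name for _, name in ids_lookup[42]])  # ['longer-name', 'short-name']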
@@ -95,6 +101,8 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
@@ -117,24 +125,21 @@ class EmbeddingDatabase:
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
- return None
+ return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
- return embedding
-
- return None
+ return embedding, len(ids)
+ return None, None
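Note: find_embedding_at_position now returns an (embedding, length) pair instead of a bare embedding, which lets the caller advance past every token a multi-token embedding consumed instead of guessing. A runnable sketch of that calling pattern with a stub lookup (the two-token sequence [42, 7] is invented):

    def find_embedding_at_position(tokens, offset):
        # Stub standing in for EmbeddingDatabase.find_embedding_at_position.
        if tokens[offset:offset + 2] == [42, 7]:
            return "my-embedding", 2
        return None, None

    tokens = [1, 42, 7, 9]
    offset = 0
    while offset < len(tokens):
        embedding, length = find_embedding_at_position(tokens, offset)
        if embedding is None:
            offset += 1
            continue
        print(f"matched {embedding} at offset {offset}, consuming {length} tokens")
        offset += length  # skip everything the match consumed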
-def create_embedding(name, num_vectors_per_token):
- init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
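Note: two things change in create_embedding: the initialization text becomes a parameter instead of the hard-coded '*', and the init vectors are read from the unhijacked token_embedding (the .wrapped layer) rather than the full embeddings module, which would also add position embeddings. The loop above then spreads the init-token rows across the requested slots; a standalone sketch of that spreading, assuming index-scaling behavior:

    import torch

    # Sketch: spread N init-token rows across M embedding slots by index scaling
    # (assumed behavior of the loop body in create_embedding).
    num_vectors_per_token = 4
    embedded = torch.randn(2, 768)  # e.g. a two-token init_text
    vec = torch.zeros(num_vectors_per_token, embedded.shape[1])
    for i in range(num_vectors_per_token):
        vec[i] = embedded[i * embedded.shape[0] // num_vectors_per_token]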
@@ -251,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
+ checkpoint = sd_models.select_checkpoint()
+
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
embedding.save(filename)
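Note: each periodic save during training now stamps the embedding with the hash and model name of the checkpoint selected for training, and clears the cached checksum so it is recomputed for the updated vectors. A sketch of how a consumer might surface that provenance later (this helper is hypothetical, not part of the webui):

    def warn_on_checkpoint_mismatch(embedding, current_checkpoint):
        # Hypothetical helper: compare the stamped hash against the active checkpoint.
        if embedding.sd_checkpoint is None:
            return  # file predates these fields; nothing recorded
        if embedding.sd_checkpoint != current_checkpoint.hash:
            print(f"embedding '{embedding.name}' was trained on "
                  f"'{embedding.sd_checkpoint_name}', not the current checkpoint")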
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index ce3677a9..66c43ffb 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
-def create_embedding(name, nvpt):
- filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+ filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
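Note: the UI wrapper now threads a user-supplied initialization text through to ti.create_embedding as the init_text keyword. An illustrative sketch of Gradio inputs matching the new (name, initialization_text, nvpt) argument order (labels and ranges are invented, not the webui's actual layout):

    import gradio as gr

    # Illustrative components only; the order mirrors
    # ui.create_embedding(name, initialization_text, nvpt).
    name = gr.Textbox(label="Name")
    initialization_text = gr.Textbox(label="Initialization text", value="*")
    nvpt = gr.Slider(minimum=1, maximum=75, step=1, value=1,
                     label="Number of vectors per token")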