aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/sd_hijack.py8
-rw-r--r--modules/textual_inversion/textual_inversion.py13
-rw-r--r--modules/textual_inversion/ui.py4
-rw-r--r--modules/ui.py2
4 files changed, 13 insertions, 14 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fd57e5c5..3fa06242 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if embedding is None:
remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
- i += emb_len
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
- i += emb_len
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c0baaace..0c50161d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -117,24 +117,21 @@ class EmbeddingDatabase:
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
- return None
+ return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
- return embedding
+ return embedding, len(ids)
- return None
+ return None, None
-
-def create_embedding(name, num_vectors_per_token):
- init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index ce3677a9..66c43ffb 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
-def create_embedding(name, nvpt):
- filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+ filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
diff --git a/modules/ui.py b/modules/ui.py
index 3b81a4f7..eca50df0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
@@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
+ initialization_text,
nvpt,
],
outputs=[