author    Billy Cao <aliencaocao@gmail.com>  2022-11-27 12:39:31 +0800
committer GitHub <noreply@github.com>        2022-11-27 12:39:31 +0800
commit    349f0461ecd5f2522179e644213e2c65da34fc7c (patch)
tree      f71359be2e68f7dc311b0126d0af669a29f92a81 /modules/textual_inversion
parent    adb6cb7619989cbc7a271cc6c2ae27bb936c43d9 (diff)
parent    b5050ad2071644f7b4c99660dc66a8a95136102f (diff)
Merge branch 'master' into support_any_resolution
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--  modules/textual_inversion/textual_inversion.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5e4d8688..a273e663 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -64,7 +64,8 @@ class EmbeddingDatabase:
         self.word_embeddings[embedding.name] = embedding
 
-        ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
+        ids = model.cond_stage_model.tokenize([embedding.name])[0]
 
         first_id = ids[0]
         if first_id not in self.ids_lookup:
             self.ids_lookup[first_id] = []
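
Note (editorial context, not part of the patch): the TODO above exists because the two supported text-encoder stacks tokenize the same string differently, so code that calls one library's tokenizer directly breaks when the other backend is active. A minimal sketch of the divergence, assuming the transformers and open_clip packages that back SD 1.x and SD 2.x respectively:

    # Hypothetical illustration; this is not code from the repository.
    from transformers import CLIPTokenizer
    import open_clip

    text = "my-embedding"

    # SD 1.x path: transformers returns a variable-length list of ids,
    # with no BOS/EOS when add_special_tokens=False.
    clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    ids_clip = clip_tok([text], add_special_tokens=False)["input_ids"][0]

    # SD 2.x path: open_clip.tokenize returns a fixed-length LongTensor
    # (padded to the 77-token context window) that always includes BOS/EOS.
    ids_open = open_clip.tokenize([text])[0]

Because the resulting id sequences differ, the wrapper-level tokenize() call lets each backend resolve an embedding's name in its own vocabulary.
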
@@ -155,13 +156,11 @@ class EmbeddingDatabase:
 def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
-    embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
     with devices.autocast():
         cond_model([""])  # will send cond model to GPU if lowvram/medvram is active
 
-    ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
-    embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+    embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
     for i in range(num_vectors_per_token):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
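
For reference, the removed lines indicate what the new encode_embedding_init_text helper must do on the CLIP side. A sketch reconstructed from that inline code, assuming it becomes a method on the hijacked cond_stage_model wrapper; the real implementation lives elsewhere in the repository and may differ:

    # Approximate reconstruction; names mirror the removed inline code above.
    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.transformer.text_model.embeddings
        ids = self.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
        return embedded

create_embedding then spreads the returned rows across the num_vectors_per_token rows of vec in the loop above, so the open_clip wrapper can supply its own tokenization and embedding lookup without this function caring which backend is loaded.
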