aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-05-31 19:29:47 +0300
committerGitHub <noreply@github.com>2023-05-31 19:29:47 +0300
commit881de0df38c1fa6d0d61f7bc6fc93c100a9f35d0 (patch)
tree7079024b0ffc02744caf1083a46268bb3cc941d2
parent670195d7202d49bec0c3d489cd7bbaac9dcd5901 (diff)
parent4635f31270d1b5d41ad63815cb400b1ca73ea859 (diff)
Merge pull request #10803 from klimaleksus/refactoring-for-embedding-merge
Refactor EmbeddingDatabase.register_embedding() to allow unregistering
-rw-r--r--modules/textual_inversion/textual_inversion.py25
1 files changed, 19 insertions, 6 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index a040a988..b3dcb140 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -119,16 +119,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ return self.register_embedding_by_name(embedding, model, embedding.name)
+ def register_embedding_by_name(self, embedding, model, name):
+ ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+ if name in self.word_embeddings:
+ # remove old one from the lookup list
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+ else:
+ lookup = self.ids_lookup[first_id]
+ if embedding is not None:
+ lookup += [(ids, embedding)]
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+ if embedding is None:
+ # unregister embedding with specified name
+ if name in self.word_embeddings:
+ del self.word_embeddings[name]
+ if len(self.ids_lookup[first_id])==0:
+ del self.ids_lookup[first_id]
+ return None
+ self.word_embeddings[name] = embedding
return embedding
def get_expected_shape(self):