aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/api/api.py17
-rw-r--r--modules/api/models.py11
-rw-r--r--modules/textual_inversion/textual_inversion.py8
3 files changed, 28 insertions, 8 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 30bf3dac..9c670f00 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -330,9 +330,22 @@ class Api:
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
return {
- "loaded": sorted(db.word_embeddings.keys()),
- "skipped": sorted(db.skipped_embeddings),
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self):
diff --git a/modules/api/models.py b/modules/api/models.py
index a8472dc9..4a632c68 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
class EmbeddingsResponse(BaseModel):
- loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1e5722e7..fd253477 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -59,7 +59,7 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
- self.skipped_embeddings = []
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
@@ -91,7 +91,7 @@ class EmbeddingDatabase:
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
- self.skipped_embeddings = []
+ self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
@@ -136,7 +136,7 @@ class EmbeddingDatabase:
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
- self.skipped_embeddings.append(name)
+ self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
- print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]