aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py44
-rw-r--r--modules/api/models.py10
2 files changed, 40 insertions, 14 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 1ceba75d..9c670f00 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -100,6 +100,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -121,7 +122,6 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -129,15 +129,14 @@ class Api:
)
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
- p = StableDiffusionProcessingTxt2Img(**vars(populate))
- # Override object param
-
- shared.state.begin()
with self.queue_lock:
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+
+ shared.state.begin()
processed = process_images(p)
+ shared.state.end()
- shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -153,7 +152,6 @@ class Api:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
@@ -165,16 +163,14 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
- p = StableDiffusionProcessingImg2Img(**args)
-
- p.init_images = [decode_base64_to_image(x) for x in init_images]
-
- shared.state.begin()
with self.queue_lock:
- processed = process_images(p)
+ p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
- shared.state.end()
+ shared.state.begin()
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -332,6 +328,26 @@ class Api:
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
+ def get_embeddings(self):
+ db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+ return {
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
+ }
+
def refresh_checkpoints(self):
shared.refresh_checkpoints()
diff --git a/modules/api/models.py b/modules/api/models.py
index c446ce7a..4a632c68 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file