From bdc90837987ed8919dd611fd01553b0c170ded5c Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Thu, 27 Oct 2022 15:20:15 -0400 Subject: Add a barebones interrogate API --- modules/api/api.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) (limited to 'modules/api/api.py') diff --git a/modules/api/api.py b/modules/api/api.py index 6e9d6097..eabdb7b8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,4 +1,4 @@ -from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_pnginfo @@ -25,6 +25,11 @@ class ImageToImageResponse(BaseModel): parameters: Json info: Json +class InterrogateResponse(BaseModel): + caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + parameters: Json + info: Json + class Api: def __init__(self, app, queue_lock): @@ -33,6 +38,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) def __base64_to_image(self, base64_string): # if has a comma, deal with prefix @@ -118,6 +124,23 @@ class Api: return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js()) + def interrogateapi(self, interrogatereq: InterrogateAPI): + image_b64 = interrogatereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + populate = interrogatereq.copy(update={ # Override __init__ params + } + ) + + img = self.__base64_to_image(image_b64) + + # Override object param + with self.queue_lock: + processed = shared.interrogator.interrogate(img) + + return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=None) + def extrasapi(self): raise NotImplementedError -- cgit v1.2.1