aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
authorarcticfaded <jbelt021@fiu.edu>2022-10-18 06:51:53 +0000
committerarcticfaded <jbelt021@fiu.edu>2022-10-18 06:51:53 +0000
commit8d5d863a9d11850464fdb6b64f34602803c15ccc (patch)
treec983dbe973ddefc84ad3d2d8222de1c999442e46 /modules/api
parent1df3ff25e6fe2e3f308e45f7a6dd37fb4f1988e6 (diff)
gradio and FastAPI
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 14613d8c..ce98cb8c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel):
class Api:
- def __init__(self, app):
+ def __init__(self, app, queue_lock):
self.router = APIRouter()
- app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app = app
+ self.queue_lock = queue_lock
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -30,7 +32,8 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- processed = process_images(p)
+ with self.queue_lock:
+ processed = process_images(p)
b64images = []
for i in processed.images:
@@ -52,5 +55,5 @@ class Api:
raise NotImplementedError
def launch(self, server_name, port):
- app.include_router(self.router)
- uvicorn.run(app, host=server_name, port=port)
+ self.app.include_router(self.router)
+ uvicorn.run(self.app, host=server_name, port=port)