aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorarcticfaded <jbelt021@fiu.edu>2022-10-18 06:51:53 +0000
committerarcticfaded <jbelt021@fiu.edu>2022-10-18 06:51:53 +0000
commit8d5d863a9d11850464fdb6b64f34602803c15ccc (patch)
treec983dbe973ddefc84ad3d2d8222de1c999442e46
parent1df3ff25e6fe2e3f308e45f7a6dd37fb4f1988e6 (diff)
gradio and FastAPI
-rw-r--r--modules/api/api.py13
-rw-r--r--webui.py16
2 files changed, 15 insertions, 14 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 14613d8c..ce98cb8c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel):
class Api:
- def __init__(self, app):
+ def __init__(self, app, queue_lock):
self.router = APIRouter()
- app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app = app
+ self.queue_lock = queue_lock
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -30,7 +32,8 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- processed = process_images(p)
+ with self.queue_lock:
+ processed = process_images(p)
b64images = []
for i in processed.images:
@@ -52,5 +55,5 @@ class Api:
raise NotImplementedError
def launch(self, server_name, port):
- app.include_router(self.router)
- uvicorn.run(app, host=server_name, port=port)
+ self.app.include_router(self.router)
+ uvicorn.run(self.app, host=server_name, port=port)
diff --git a/webui.py b/webui.py
index 6212be01..71724c3b 100644
--- a/webui.py
+++ b/webui.py
@@ -4,7 +4,7 @@ import time
import importlib
import signal
import threading
-
+from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path
@@ -31,7 +31,6 @@ from modules.paths import script_path
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
-
queue_lock = threading.Lock()
@@ -97,7 +96,7 @@ def initialize():
def create_api(app):
from modules.api.api import Api
- api = Api(app)
+ api = Api(app, queue_lock)
return api
def wait_on_server(demo=None):
@@ -141,7 +140,7 @@ def webui(launch_api=False):
create_api(app)
wait_on_server(demo)
-
+
sd_samplers.set_samplers()
print('Reloading Custom Scripts')
@@ -153,11 +152,10 @@ def webui(launch_api=False):
print('Restarting Gradio')
+
+task = []
if __name__ == "__main__":
- if not cmd_opts.nowebui:
+ if cmd_opts.nowebui:
api_only()
-
- if cmd_opts.api:
- webui(True)
else:
- webui(False)
+ webui(cmd_opts.api) \ No newline at end of file