aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/api/api.py45
-rw-r--r--requirements.txt1
2 files changed, 45 insertions, 1 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 35e17afc..8d4ff4e0 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -6,8 +6,11 @@ import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
+from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from fastapi.exceptions import HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
@@ -90,6 +93,16 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
+ rich_available = True
+ try:
+ import anyio # importing just so it can be placed on silent list
+ import starlette # importing just so it can be placed on silent list
+ from rich.console import Console
+ console = Console()
+ except:
+ import traceback
+ rich_available = False
+
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
@@ -110,6 +123,36 @@ def api_middleware(app: FastAPI):
))
return res
+ def handle_exception(request: Request, e: Exception):
+ err = {
+ "error": type(e).__name__,
+ "detail": vars(e).get('detail', ''),
+ "body": vars(e).get('body', ''),
+ "errors": str(e),
+ }
+ print(f"API error: {request.method}: {request.url} {err}")
+ if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ if rich_available:
+ console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
+ else:
+ traceback.print_exc()
+ return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
+
+ @app.middleware("http")
+ async def exception_handling(request: Request, call_next):
+ try:
+ return await call_next(request)
+ except Exception as e:
+ return handle_exception(request, e)
+
+ @app.exception_handler(Exception)
+ async def fastapi_exception_handler(request: Request, e: Exception):
+ return handle_exception(request, e)
+
+ @app.exception_handler(HTTPException)
+ async def http_exception_handler(request: Request, e: HTTPException):
+ return handle_exception(request, e)
+
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
diff --git a/requirements.txt b/requirements.txt
index 6d53f089..69c1e7cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,3 +30,4 @@ GitPython
torchsde
safetensors
psutil
+rich