From 5387576c59632f6bc85c81de4d19219f63437d0a Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 15 Mar 2023 15:11:04 -0400 Subject: api error handler --- modules/api/api.py | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) (limited to 'modules/api') diff --git a/modules/api/api.py b/modules/api/api.py index 35e17afc..8d4ff4e0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,8 +6,11 @@ import uvicorn from threading import Lock from io import BytesIO from gradio.processing_utils import decode_base64_to_file -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response +from fastapi import APIRouter, Depends, FastAPI, Request, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.exceptions import HTTPException +from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared @@ -90,6 +93,16 @@ def encode_pil_to_base64(image): return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): + rich_available = True + try: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + console = Console() + except: + import traceback + rich_available = False + @app.middleware("http") async def log_and_time(req: Request, call_next): ts = time.time() @@ -110,6 +123,36 @@ def api_middleware(app: FastAPI): )) return res + def handle_exception(request: Request, e: Exception): + err = { + "error": type(e).__name__, + "detail": vars(e).get('detail', ''), + "body": vars(e).get('body', ''), + "errors": str(e), + } + print(f"API error: {request.method}: {request.url} {err}") + if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if rich_available: + console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) + else: + traceback.print_exc() + return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) + + @app.middleware("http") + async def exception_handling(request: Request, call_next): + try: + return await call_next(request) + except Exception as e: + return handle_exception(request, e) + + @app.exception_handler(Exception) + async def fastapi_exception_handler(request: Request, e: Exception): + return handle_exception(request, e) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, e: HTTPException): + return handle_exception(request, e) + class Api: def __init__(self, app: FastAPI, queue_lock: Lock): -- cgit v1.2.1 From 4cbbb881ee530d9b9ba18027e2b0057e6a2c4ee1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com> Date: Thu, 9 Mar 2023 07:56:19 +0300 Subject: Unload checkpoints on Request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …to free VRAM. New Action buttons in the settings to manually free and reload checkpoints, essentially juggling models between RAM and VRAM. --- modules/api/api.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'modules/api') diff --git a/modules/api/api.py b/modules/api/api.py index 35e17afc..f52f7fef 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -150,6 +150,8 @@ class Api: self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) + self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) def add_api_route(self, path: str, endpoint, **kwargs): @@ -412,6 +414,16 @@ class Api: return {} + def unloadapi(self): + unload_model_weights() + + return {} + + def reloadapi(self): + reload_model_weights() + + return {} + def skip(self): shared.state.skip() -- cgit v1.2.1