aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/api/api.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 648bd6a8..efcedbba 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -4,7 +4,7 @@ import time
import uvicorn
from threading import Lock
from io import BytesIO
-from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
@@ -41,6 +41,10 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
+def decode_base64_to_image(encoding):
+ if encoding.startswith("data:image/"):
+ encoding = encoding.split(";")[1].split(",")[1]
+ return Image.open(BytesIO(base64.b64decode(encoding)))
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
@@ -134,10 +138,7 @@ class Api:
mask = img2imgreq.mask
if mask:
- if mask.startswith("data:image/"):
- mask = decode_base64_to_image(mask)
- else:
- mask = Image.open(BytesIO(base64.b64decode(mask)))
+ mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
@@ -151,10 +152,7 @@ class Api:
imgs = []
for img in init_images:
- if img.startswith("data:image/"):
- img = decode_base64_to_image(img)
- else:
- img = Image.open(BytesIO(base64.b64decode(img)))
+ img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs