aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/api/api.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 3df6ff96..3caa83a4 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -33,6 +33,14 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ def __base64_to_image(self, base64_string):
+ # if has a comma, deal with prefix
+ if "," in base64_string:
+ base64_string = base64_string.split(",")[1]
+ imgdata = base64.b64decode(base64_string)
+ # convert base64 to PIL image
+ return Image.open(io.BytesIO(imgdata))
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -74,26 +82,22 @@ class Api:
mask = img2imgreq.mask
if mask:
- raise HTTPException(status_code=400, detail="Mask not supported yet")
+ mask = self.__base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
- "do_not_save_grid": True
+ "do_not_save_grid": True,
+ "mask": mask
}
)
p = StableDiffusionProcessingImg2Img(**vars(populate))
imgs = []
for img in init_images:
- # if has a comma, deal with prefix
- if "," in img:
- img = img.split(",")[1]
- # convert base64 to PIL image
- img = base64.b64decode(img)
- img = Image.open(io.BytesIO(img))
+ img = self.__base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs