aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py26
1 files changed, 9 insertions, 17 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 596a6616..0eccccbb 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -6,9 +6,9 @@ from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
+from modules import sd_samplers
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
@@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+def validate_sampler_name(name):
+ config = sd_samplers.all_samplers_map.get(name, None)
+ if config is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+ return name
def setUpscalers(req: dict):
reqDict = vars(req)
@@ -82,14 +86,9 @@ class Api:
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
- sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
- if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_index": sampler_index[0],
+ "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
@@ -109,12 +108,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
- sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
- if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
-
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -123,10 +116,9 @@ class Api:
if mask:
mask = decode_base64_to_image(mask)
-
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_index": sampler_index[0],
+ "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
@@ -272,7 +264,7 @@ class Api:
return vars(shared.cmd_opts)
def get_samplers(self):
- return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
+ return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
upscalers = []