aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py43
1 files changed, 13 insertions, 30 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 5c5b210f..6c06d449 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-# copy from wrap_gradio_gpu_call of webui.py
-# because queue lock will be acquired in api handlers
-# and time start needs to be set
-# the function has been modified into two parts
-
-def before_gpu_call():
- devices.torch_gc()
-
- shared.state.sampling_step = 0
- shared.state.job_count = -1
- shared.state.job_no = 0
- shared.state.job_timestamp = shared.state.get_job_timestamp()
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.current_image_sampling_step = 0
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.textinfo = None
- shared.state.time_start = time.time()
-
-def after_gpu_call():
- shared.state.job = ""
- shared.state.job_count = 0
-
- devices.torch_gc()
def upscaler_to_index(name: str):
try:
@@ -41,8 +16,10 @@ def upscaler_to_index(name: str):
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -78,10 +56,13 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -119,11 +100,13 @@ class Api:
imgs = [img] * p.batch_size
p.init_images = imgs
- # Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))