Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py    | 115
-rw-r--r--  modules/api/models.py |  27
-rw-r--r--  modules/shared.py     |  13
3 files changed, 129 insertions(+), 26 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 49c213ea..9d68ac23 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,12 +1,70 @@
+# import time
+
+# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+# from modules.sd_samplers import all_samplers
+# from modules.extras import run_pnginfo
+# import modules.shared as shared
+# from modules import devices
+# import uvicorn
+# from fastapi import Body, APIRouter, HTTPException
+# from fastapi.responses import JSONResponse
+# from pydantic import BaseModel, Field, Json
+# from typing import List
+# import json
+# import io
+# import base64
+# from PIL import Image
+
+# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
+# class TextToImageResponse(BaseModel):
+#     images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+#     parameters: Json
+#     info: Json
+
+# class ImageToImageResponse(BaseModel):
+#     images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+#     parameters: Json
+#     info: Json
+
+import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException
import modules.shared as shared
+from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras
+# Copied from wrap_gradio_gpu_call in webui.py: the queue lock is acquired
+# inside the API handlers and the start time needs to be set there, so the
+# original function has been split into two parts.
+
+def before_gpu_call():
+    devices.torch_gc()
+
+    shared.state.sampling_step = 0
+    shared.state.job_count = -1
+    shared.state.job_no = 0
+    shared.state.job_timestamp = shared.state.get_job_timestamp()
+    shared.state.current_latent = None
+    shared.state.current_image = None
+    shared.state.current_image_sampling_step = 0
+    shared.state.skipped = False
+    shared.state.interrupted = False
+    shared.state.textinfo = None
+    shared.state.time_start = time.time()
+
+def after_gpu_call():
+    shared.state.job = ""
+    shared.state.job_count = 0
+
+    devices.torch_gc()
+
def upscaler_to_index(name: str):
    try:
        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
@@ -32,15 +90,16 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -48,34 +107,36 @@ class Api:
        )
        p = StableDiffusionProcessingTxt2Img(**vars(populate))
        # Override object param
+        before_gpu_call()
        with self.queue_lock:
            processed = process_images(p)
-
+        after_gpu_call()
+
        b64images = list(map(encode_pil_to_base64, processed.images))
-
+
        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
+
        if sampler_index is None:
-            raise HTTPException(status_code=404, detail="Sampler not found")
+            raise HTTPException(status_code=404, detail="Sampler not found")

        init_images = img2imgreq.init_images
        if init_images is None:
-            raise HTTPException(status_code=404, detail="Init image not found")
+            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)
-
+
        populate = img2imgreq.copy(update={ # Override __init__ params
-            "sd_model": shared.sd_model,
+            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
-            "do_not_save_grid": True,
+            "do_not_save_grid": True,
            "mask": mask
        }
        )
@@ -88,15 +149,17 @@ class Api:
        p.init_images = imgs
        # Override object param
+        before_gpu_call()
        with self.queue_lock:
            processed = process_images(p)
-
+        after_gpu_call()
+
        b64images = list(map(encode_pil_to_base64, processed.images))

        if (not img2imgreq.include_init_images):
            img2imgreq.init_images = None
            img2imgreq.mask = None
-
+
        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
@@ -124,7 +187,29 @@ class Api:
            result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
-
+
+    def progressapi(self):
+        # copied from check_progress_call in ui.py
+
+        if shared.state.job_count == 0:
+            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
+
+        # start at a small value to avoid division by zero
+        progress = 0.01
+
+        if shared.state.job_count > 0:
+            progress += shared.state.job_no / shared.state.job_count
+        if shared.state.sampling_steps > 0:
+            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+        time_since_start = time.time() - shared.state.time_start
+        eta = time_since_start / progress
+        eta_relative = eta - time_since_start
+
+        progress = min(progress, 1)
+
+        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
+
    def pnginfoapi(self):
        raise NotImplementedError
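
For context, here is a minimal client-side sketch of how the new GET /sdapi/v1/progress route could be polled while a txt2img or img2img job runs. The route path and the progress / eta_relative / state fields come from this patch; the base URL (the webui default of http://127.0.0.1:7860) and the use of the requests library are illustrative assumptions, not part of the change.

# progress_poll.py - hypothetical polling client for the /sdapi/v1/progress
# route added in this patch. Adjust BASE_URL for your setup.
import time

import requests

BASE_URL = "http://127.0.0.1:7860"  # assumption: default webui host/port


def poll_progress(interval=1.0):
    while True:
        data = requests.get(f"{BASE_URL}/sdapi/v1/progress").json()

        # progress is in [0, 1]; eta_relative is the estimated seconds left,
        # computed server-side as elapsed / progress - elapsed.
        print(f"progress={data['progress']:.2f}  eta={data['eta_relative']:.1f}s")

        # progressapi returns progress=0 once job_count drops back to 0,
        # i.e. when no job is running any more.
        if data["progress"] == 0 or data["progress"] >= 1:
            break
        time.sleep(interval)


if __name__ == "__main__":
    poll_progress()
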
diff --git a/modules/api/models.py b/modules/api/models.py
index dd122321..c374a627 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -51,17 +51,17 @@ class PydanticModelGenerator:
            # field_type = str if not overrides.get(k) else overrides[k]["type"]
            # print(k, v.annotation, v.default)
            field_type = v.annotation
-
+
            return Optional[field_type]
-
+
        def merge_class_params(class_):
            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
            parameters = {}
            for classes in all_classes:
                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
            return parameters
-
-
+
+
        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
@@ -73,11 +73,11 @@ class PydanticModelGenerator:
            )
            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]
-
+
        for fields in additional_fields:
            self._model_def.append(ModelDef(
-                field=underscore(fields["key"]),
-                field_alias=fields["key"],
+                field=underscore(fields["key"]),
+                field_alias=fields["key"],
                field_type=fields["type"],
                field_value=fields["default"],
                field_exclude=fields["exclude"] if "exclude" in fields else False))
@@ -94,15 +94,15 @@ class PydanticModelGenerator:
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel
-
+
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
-    "StableDiffusionProcessingTxt2Img",
+    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
-    "StableDiffusionProcessingImg2Img",
+    "StableDiffusionProcessingImg2Img",
    StableDiffusionProcessingImg2Img,
    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
@@ -148,4 +148,9 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
    imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")

class ExtrasBatchImagesResponse(ExtraBaseResponse):
-    images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
+    images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class ProgressResponse(BaseModel):
+    progress: float
+    eta_relative: float
+    state: dict
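
A ProgressResponse payload as seen by a client might look like the sketch below. The field names follow the model above and the keys inside state mirror State.js() from modules/shared.py later in this diff; the concrete values are invented for illustration.

# Hypothetical example of a serialized ProgressResponse (values invented):
example_progress_response = {
    "progress": 0.37,        # fraction of the current job, capped at 1
    "eta_relative": 42.5,    # estimated seconds remaining
    "state": {               # snapshot produced by shared.State.js()
        "skipped": False,
        "interrupted": False,
        "job": "",
        "job_count": 1,
        "job_no": 0,
        "sampling_step": 7,
        "sampling_steps": 20,
    },
}
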
diff --git a/modules/shared.py b/modules/shared.py
index fb84afd8..0f4c035d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -147,6 +147,19 @@ class State:
    def get_job_timestamp(self):
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?

+    def js(self):
+        obj = {
+            "skipped": self.skipped,
+            "interrupted": self.interrupted,
+            "job": self.job,
+            "job_count": self.job_count,
+            "job_no": self.job_no,
+            "sampling_step": self.sampling_step,
+            "sampling_steps": self.sampling_steps,
+        }
+
+        return json.dumps(obj)
+

state = State()
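
To make the ETA math in progressapi concrete, here is a small standalone sketch that reproduces the same arithmetic with sample counter values of the kind State.js() exposes; the numbers are invented, only the formula comes from the patch.

# Worked example of the progress/ETA estimate used by progressapi
# (sample values are invented; the formula mirrors the patch).
job_no, job_count = 0, 1          # first (and only) job in the batch
sampling_step, sampling_steps = 5, 20
time_since_start = 4.0            # seconds elapsed since before_gpu_call()

progress = 0.01                   # small offset to avoid dividing by zero
progress += job_no / job_count                              # finished jobs
progress += 1 / job_count * sampling_step / sampling_steps  # current job

eta = time_since_start / progress          # projected total duration
eta_relative = eta - time_since_start      # estimated seconds remaining

print(round(progress, 3), round(eta_relative, 1))  # 0.26 11.4
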