aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruno Seoane <brunoseoaneamarillo@gmail.com>2022-10-23 15:35:49 -0300
committerBruno Seoane <brunoseoaneamarillo@gmail.com>2022-10-23 15:35:49 -0300
commit866b36d705a338d299aba385788729d60f7d48c8 (patch)
treeda4098f9b56fb4d969a8c97d23eb7dab5f8e5737
parente0ca4dfbc10e0af8dfc4185e5e758f33fd2f0d81 (diff)
Move processing's models into models.py
It didn't make sense to have two differente files for the same and "models" is a more descriptive name.
-rw-r--r--modules/api/api.py57
-rw-r--r--modules/api/models.py112
-rw-r--r--modules/api/processing.py106
3 files changed, 119 insertions, 156 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 3acb1f36..20e85e82 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,16 +1,11 @@
-from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
-import modules.shared as shared
import uvicorn
+from gradio import processing_utils
from fastapi import APIRouter, HTTPException
-import json
-import io
-import base64
+import modules.shared as shared
from modules.api.models import *
-from PIL import Image
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.sd_samplers import all_samplers
from modules.extras import run_extras
-from gradio import processing_utils
def upscaler_to_index(name: str):
try:
@@ -20,29 +15,6 @@ def upscaler_to_index(name: str):
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-# def img_to_base64(img: str):
-# buffer = io.BytesIO()
-# img.save(buffer, format="png")
-# return base64.b64encode(buffer.getvalue())
-
-# def base64_to_bytes(base64Img: str):
-# if "," in base64Img:
-# base64Img = base64Img.split(",")[1]
-# return io.BytesIO(base64.b64decode(base64Img))
-
-# def base64_to_images(base64Imgs: list[str]):
-# imgs = []
-# for img in base64Imgs:
-# img = Image.open(base64_to_bytes(img))
-# imgs.append(img)
-# return imgs
-
-class ImageToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: dict
- info: str
-
-
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -51,15 +23,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
- self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
-
- # def __base64_to_image(self, base64_string):
- # # if has a comma, deal with prefix
- # if "," in base64_string:
- # base64_string = base64_string.split(",")[1]
- # imgdata = base64.b64decode(base64_string)
- # # convert base64 to PIL image
- # return Image.open(io.BytesIO(imgdata))
+ self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -81,7 +45,7 @@ class Api:
b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info)
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info)
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -120,10 +84,7 @@ class Api:
processed = process_images(p)
b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
- # for i in processed.images:
- # buffer = io.BytesIO()
- # i.save(buffer, format="png")
- # b64images.append(base64.b64encode(buffer.getvalue()))
+
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info)
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
@@ -134,12 +95,12 @@ class Api:
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
- reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image'])
+ reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")
- return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2])
+ return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
upscaler1Index = upscaler_to_index(req.upscaler_1)
diff --git a/modules/api/models.py b/modules/api/models.py
index 209f8af5..362e6277 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,10 +1,118 @@
-from pydantic import BaseModel, Field, Json
+import inspect
+from pydantic import BaseModel, Field, Json, create_model
+from typing import Any, Optional
from typing_extensions import Literal
+from inflection import underscore
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+
+
+class PydanticModelGenerator:
+ """
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ class_instance = None,
+ additional_fields = None,
+ ):
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
+
+ return Optional[field_type]
+
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
+ self._model_name = model_name
+ self._class_data = merge_class_params(class_instance)
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v),
+ field_value=v.default
+ )
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+ ]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"]))
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
+ return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+).generate_model()
+
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: str
+ parameters: dict
+ info: str
+
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
info: str
class ExtrasBaseRequest(BaseModel):
diff --git a/modules/api/processing.py b/modules/api/processing.py
deleted file mode 100644
index f551fa35..00000000
--- a/modules/api/processing.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from array import array
-from inflection import underscore
-from typing import Any, Dict, Optional
-from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-import inspect
-
-
-API_NOT_ALLOWED = [
- "self",
- "kwargs",
- "sd_model",
- "outpath_samples",
- "outpath_grids",
- "sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
- "extra_generation_params",
- "overlay_images",
- "do_not_reload_embeddings",
- "seed_enable_extras",
- "prompt_for_display",
- "sampler_noise_scheduler_override",
- "ddim_discretize"
-]
-
-class ModelDef(BaseModel):
- """Assistance Class for Pydantic Dynamic Model Generation"""
-
- field: str
- field_alias: str
- field_type: Any
- field_value: Any
-
-
-class PydanticModelGenerator:
- """
- Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
- source_data is a snapshot of the default values produced by the class
- params are the names of the actual keys required by __init__
- """
-
- def __init__(
- self,
- model_name: str = None,
- class_instance = None,
- additional_fields = None,
- ):
- def field_type_generator(k, v):
- # field_type = str if not overrides.get(k) else overrides[k]["type"]
- # print(k, v.annotation, v.default)
- field_type = v.annotation
-
- return Optional[field_type]
-
- def merge_class_params(class_):
- all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
- parameters = {}
- for classes in all_classes:
- parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
- return parameters
-
-
- self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
- self._model_def = [
- ModelDef(
- field=underscore(k),
- field_alias=k,
- field_type=field_type_generator(k, v),
- field_value=v.default
- )
- for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
- ]
-
- for fields in additional_fields:
- self._model_def.append(ModelDef(
- field=underscore(fields["key"]),
- field_alias=fields["key"],
- field_type=fields["type"],
- field_value=fields["default"]))
-
- def generate_model(self):
- """
- Creates a pydantic BaseModel
- from the json and overrides provided at initialization
- """
- fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
- }
- DynamicModel = create_model(self._model_name, **fields)
- DynamicModel.__config__.allow_population_by_field_name = True
- DynamicModel.__config__.allow_mutation = True
- return DynamicModel
-
-StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingTxt2Img",
- StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
-).generate_model()
-
-StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingImg2Img",
- StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
-).generate_model() \ No newline at end of file