aboutsummaryrefslogtreecommitdiff
path: root/modules/api/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api/models.py')
-rw-r--r--modules/api/models.py52
1 files changed, 31 insertions, 21 deletions
diff --git a/modules/api/models.py b/modules/api/models.py
index 800c9b93..16edf11c 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,12 +1,10 @@
import inspect
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
+from typing import Any, Optional, Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@@ -50,10 +48,12 @@ class PydanticModelGenerator:
additional_fields = None,
):
def field_type_generator(k, v):
- # field_type = str if not overrides.get(k) else overrides[k]["type"]
- # print(k, v.annotation, v.default)
field_type = v.annotation
+ if field_type == 'Image':
+ # images are sent as base64 strings via API
+ field_type = 'str'
+
return Optional[field_type]
def merge_class_params(class_):
@@ -63,7 +63,6 @@ class PydanticModelGenerator:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
-
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
@@ -72,7 +71,7 @@ class PydanticModelGenerator:
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
- field_value=v.default
+ field_value=None if isinstance(v.default, property) else v.default
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
@@ -108,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -125,16 +126,18 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -167,17 +170,18 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
- items: dict = Field(title="Items", description="An object containing all the info the image had")
+ items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
+ parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@@ -202,9 +206,6 @@ class TrainResponse(BaseModel):
class CreateResponse(BaseModel):
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
-class PreprocessResponse(BaseModel):
- info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
-
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
@@ -231,8 +232,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
- aliases: List[str] = Field(title="Aliases")
- options: Dict[str, str] = Field(title="Options")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
@@ -283,8 +284,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
- loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+ loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
@@ -302,11 +303,20 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
- choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+ choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
- args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+ args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+
+class ExtensionItem(BaseModel):
+ name: str = Field(title="Name", description="Extension name")
+ remote: str = Field(title="Remote", description="Extension Repository URL")
+ branch: str = Field(title="Branch", description="Extension Repository Branch")
+ commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
+ version: str = Field(title="Version", description="Extension Version")
+ commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
+ enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")