aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
authorCodeHatchling <steve@codehatch.com>2023-12-02 21:14:02 -0700
committerCodeHatchling <steve@codehatch.com>2023-12-02 21:14:02 -0700
commit3bd3a091604a332de6ff249870dabd2a91215499 (patch)
tree0323625627748ee44fc192bb2496585a4db56b5a /modules/api
parentbb04d400c95df01d191ef6c1a43e66b95425fa33 (diff)
parentf0f100e67b78f686dc73cf3c8cad422e45cc9b8a (diff)
Merge remote-tracking branch 'origin/dev' into soft-inpainting
# Conflicts: # modules/processing.py
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py66
-rw-r--r--modules/api/models.py36
2 files changed, 64 insertions, 38 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index e6edffe7..09083874 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,19 +17,18 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
+from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
-from typing import Dict, List, Any
+from typing import Any
import piexif
import piexif.helper
from contextlib import closing
@@ -103,7 +102,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
-
+ if isinstance(image, str):
+ return image
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -221,15 +221,15 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
- self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
- self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@@ -242,7 +242,8 @@ class Api:
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
- self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
+ self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -473,9 +474,6 @@ class Api:
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: models.PNGInfoRequest):
- if(not req.image.strip()):
- return models.PNGInfoResponse(info="")
-
image = decode_base64_to_image(req.image.strip())
if image is None:
return models.PNGInfoResponse(info="")
@@ -484,9 +482,10 @@ class Api:
if geninfo is None:
geninfo = ""
- items = {**{'parameters': geninfo}, **items}
+ params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
+ script_callbacks.infotext_pasted_callback(geninfo, params)
- return models.PNGInfoResponse(info=geninfo, items=items)
+ return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
@@ -541,12 +540,12 @@ class Api:
return {}
def unloadapi(self):
- unload_model_weights()
+ sd_models.unload_model_weights()
return {}
def reloadapi(self):
- reload_model_weights()
+ sd_models.send_model_to_device(shared.sd_model)
return {}
@@ -564,9 +563,9 @@ class Api:
return options
- def set_config(self, req: Dict[str, Any]):
+ def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
- if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+ if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
@@ -770,6 +769,25 @@ class Api:
cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda)
+ def get_extensions_list(self):
+ from modules import extensions
+ extensions.list_extensions()
+ ext_list = []
+ for ext in extensions.extensions:
+ ext: extensions.Extension
+ ext.read_info_from_repo()
+ if ext.remote is not None:
+ ext_list.append({
+ "name": ext.name,
+ "remote": ext.remote,
+ "branch": ext.branch,
+ "commit_hash":ext.commit_hash,
+ "commit_date":ext.commit_date,
+ "version":ext.version,
+ "enabled":ext.enabled
+ })
+ return ext_list
+
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
diff --git a/modules/api/models.py b/modules/api/models.py
index 6a574771..a0d80af8 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,12 +1,10 @@
import inspect
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
+from typing import Any, Optional, Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -168,17 +166,18 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
- items: dict = Field(title="Items", description="An object containing all the info the image had")
+ items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
+ parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@@ -232,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
- aliases: List[str] = Field(title="Aliases")
- options: Dict[str, str] = Field(title="Options")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
@@ -284,8 +283,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
- loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+ loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
@@ -303,11 +302,20 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
- choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+ choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
- args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+ args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+
+class ExtensionItem(BaseModel):
+ name: str = Field(title="Name", description="Extension name")
+ remote: str = Field(title="Remote", description="Extension Repository URL")
+ branch: str = Field(title="Branch", description="Extension Repository Branch")
+ commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
+ version: str = Field(title="Version", description="Extension Version")
+ commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
+ enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")