diff options
author | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-07 22:43:38 +0700 |
---|---|---|
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-07 22:43:38 +0700 |
commit | cabd4e3b3bf91e0cb5071398a8efddef495f6311 (patch) | |
tree | 55daa888a7e03e2e204daf6729835b94277350a2 /modules/api | |
parent | bb832d7725187f8a8ab44faa6ee1b38cb5f600aa (diff) | |
parent | 804d9fb83d0c63ca3acd36378707ce47b8f12599 (diff) |
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules/api')
-rw-r--r-- | modules/api/api.py | 53 | ||||
-rw-r--r-- | modules/api/models.py | 40 |
2 files changed, 66 insertions, 27 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index a49f3755..688469ad 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,6 +10,7 @@ from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo +from PIL import PngImagePlugin from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List @@ -34,9 +35,21 @@ def setUpscalers(req: dict): def encode_pil_to_base64(image): - buffer = io.BytesIO() - image.save(buffer, format="png") - return base64.b64encode(buffer.getvalue()) + with io.BytesIO() as output_bytes: + + # Copy any text-only metadata + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + image.save( + output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) + ) + bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) class Api: @@ -50,6 +63,7 @@ class Api: self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) + self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) @@ -201,11 +215,24 @@ class Api: return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) + def interrogateapi(self, interrogatereq: InterrogateRequest): + image_b64 = interrogatereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + img = self.__base64_to_image(image_b64) + + # Override object param + with self.queue_lock: + processed = shared.interrogator.interrogate(img) + + return InterrogateResponse(caption=processed) + def interruptapi(self): shared.state.interrupt() return {} - + def get_config(self): options = {} for key in shared.opts.data.keys(): @@ -214,10 +241,14 @@ class Api: options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)}) else: options.update({key: shared.opts.data.get(key, None)}) - + return options - + def set_config(self, req: OptionsModel): + # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will + # overwrite all options with default values. + raise RuntimeError('Setting options via API is not supported') + reqDict = vars(req) for o in reqDict: setattr(shared.opts, o, reqDict[o]) @@ -233,13 +264,13 @@ class Api: def get_upscalers(self): upscalers = [] - + for upscaler in shared.sd_upscalers: u = upscaler.scaler upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url}) - + return upscalers - + def get_sd_models(self): return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()] @@ -251,11 +282,11 @@ class Api: def get_realesrgan_models(self): return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)] - + def get_promp_styles(self): styleList = [] for k in shared.prompt_styles.styles: - style = shared.prompt_styles.styles[k] + style = shared.prompt_styles.styles[k] styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]}) return styleList diff --git a/modules/api/models.py b/modules/api/models.py index 2ae75f43..34dbfa16 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,11 +1,11 @@ import inspect from pydantic import BaseModel, Field, create_model -from typing import Any, Optional, Union +from typing import Any, Optional from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers, opts, parser -from typing import List +from typing import Dict, List API_NOT_ALLOWED = [ "self", @@ -65,6 +65,7 @@ class PydanticModelGenerator: self._model_name = model_name self._class_data = merge_class_params(class_instance) + self._model_def = [ ModelDef( field=underscore(k), @@ -167,6 +168,12 @@ class ProgressResponse(BaseModel): state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") +class InterrogateRequest(BaseModel): + image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + +class InterrogateResponse(BaseModel): + caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + fields = {} for key, value in opts.data.items(): metadata = opts.data_labels.get(key) @@ -185,22 +192,22 @@ _options = vars(parser)['_option_string_actions'] for key in _options: if(_options[key].dest != 'help'): flag = _options[key] - _type = str - if(_options[key].default != None): _type = type(_options[key].default) + _type = str + if _options[key].default is not None: _type = type(_options[key].default) flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) FlagsModel = create_model("Flags", **flags) class SamplerItem(BaseModel): name: str = Field(title="Name") - aliases: list[str] = Field(title="Aliases") - options: dict[str, str] = Field(title="Options") + aliases: List[str] = Field(title="Aliases") + options: Dict[str, str] = Field(title="Options") class UpscalerItem(BaseModel): name: str = Field(title="Name") - model_name: str | None = Field(title="Model Name") - model_path: str | None = Field(title="Path") - model_url: str | None = Field(title="URL") + model_name: Optional[str] = Field(title="Model Name") + model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") class SDModelItem(BaseModel): title: str = Field(title="Title") @@ -211,23 +218,24 @@ class SDModelItem(BaseModel): class HypernetworkItem(BaseModel): name: str = Field(title="Name") - path: str | None = Field(title="Path") + path: Optional[str] = Field(title="Path") class FaceRestorerItem(BaseModel): name: str = Field(title="Name") - cmd_dir: str | None = Field(title="Path") + cmd_dir: Optional[str] = Field(title="Path") class RealesrganItem(BaseModel): name: str = Field(title="Name") - path: str | None = Field(title="Path") - scale: int | None = Field(title="Scale") + path: Optional[str] = Field(title="Path") + scale: Optional[int] = Field(title="Scale") class PromptStyleItem(BaseModel): name: str = Field(title="Name") - prompt: str | None = Field(title="Prompt") - negative_prompt: str | None = Field(title="Negative Prompt") + prompt: Optional[str] = Field(title="Prompt") + negative_prompt: Optional[str] = Field(title="Negative Prompt") class ArtistItem(BaseModel): name: str = Field(title="Name") score: float = Field(title="Score") - category: str = Field(title="Category")
\ No newline at end of file + category: str = Field(title="Category") + |