aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-11-04 10:46:51 +0300
committerGitHub <noreply@github.com>2022-11-04 10:46:51 +0300
commit371c4b990eca3f2418e62ce8c852e9a52d39e445 (patch)
tree52fdba21c023885985bd6b7ef2f5297d4f88f012
parentf674c488d9701e577e2aaf25e331fb44ada4f1ef (diff)
parent17bd3f4ea730436599849eddbaa78e2879b793d2 (diff)
Merge pull request #4218 from bamarillo/utils-endpoints
[API][Feature] Utils endpoints
-rw-r--r--modules/api/api.py85
-rw-r--r--modules/api/models.py70
-rw-r--r--test/utils_test.py63
3 files changed, 210 insertions, 8 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 71c9c160..a49f3755 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -2,14 +2,17 @@ import base64
import io
import time
import uvicorn
-from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
-from fastapi import APIRouter, Depends, HTTPException
+from threading import Lock
+from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
+from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-
+from modules.sd_models import checkpoints_list
+from modules.realesrgan_model import get_realesrgan_models
+from typing import List
def upscaler_to_index(name: str):
try:
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
class Api:
- def __init__(self, app, queue_lock):
+ def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
@@ -48,6 +51,18 @@ class Api:
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
+ self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
+ self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
+ self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
+ self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
+ self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
+ self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
+ self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
+ self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
+ self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -190,6 +205,66 @@ class Api:
shared.state.interrupt()
return {}
+
+ def get_config(self):
+ options = {}
+ for key in shared.opts.data.keys():
+ metadata = shared.opts.data_labels.get(key)
+ if(metadata is not None):
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+ else:
+ options.update({key: shared.opts.data.get(key, None)})
+
+ return options
+
+ def set_config(self, req: OptionsModel):
+ reqDict = vars(req)
+ for o in reqDict:
+ setattr(shared.opts, o, reqDict[o])
+
+ shared.opts.save(shared.config_filename)
+ return
+
+ def get_cmd_flags(self):
+ return vars(shared.cmd_opts)
+
+ def get_samplers(self):
+ return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
+
+ def get_upscalers(self):
+ upscalers = []
+
+ for upscaler in shared.sd_upscalers:
+ u = upscaler.scaler
+ upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
+
+ return upscalers
+
+ def get_sd_models(self):
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+
+ def get_hypernetworks(self):
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+
+ def get_face_restorers(self):
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+
+ def get_realesrgan_models(self):
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+
+ def get_promp_styles(self):
+ styleList = []
+ for k in shared.prompt_styles.styles:
+ style = shared.prompt_styles.styles[k]
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
+
+ return styleList
+
+ def get_artists_categories(self):
+ return shared.artist_db.cats
+
+ def get_artists(self):
+ return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 68fb45c6..8933e183 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,11 +1,10 @@
import inspect
-from click import prompt
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
+from typing import Any, Optional, Union
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-from modules.shared import sd_upscalers
+from modules.shared import sd_upscalers, opts, parser
API_NOT_ALLOWED = [
"self",
@@ -166,3 +165,68 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+
+fields = {}
+for key, value in opts.data.items():
+ metadata = opts.data_labels.get(key)
+ optType = opts.typemap.get(type(value), type(value))
+
+ if (metadata is not None):
+ fields.update({key: (Optional[optType], Field(
+ default=metadata.default ,description=metadata.label))})
+ else:
+ fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+ if(_options[key].dest != 'help'):
+ flag = _options[key]
+ _type = str
+ if(_options[key].default != None): _type = type(_options[key].default)
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+ name: str = Field(title="Name")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
+
+class UpscalerItem(BaseModel):
+ name: str = Field(title="Name")
+ model_name: str | None = Field(title="Model Name")
+ model_path: str | None = Field(title="Path")
+ model_url: str | None = Field(title="URL")
+
+class SDModelItem(BaseModel):
+ title: str = Field(title="Title")
+ model_name: str = Field(title="Model Name")
+ hash: str = Field(title="Hash")
+ filename: str = Field(title="Filename")
+ config: str = Field(title="Config file")
+
+class HypernetworkItem(BaseModel):
+ name: str = Field(title="Name")
+ path: str | None = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+ name: str = Field(title="Name")
+ cmd_dir: str | None = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+ name: str = Field(title="Name")
+ path: str | None = Field(title="Path")
+ scale: int | None = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+ name: str = Field(title="Name")
+ prompt: str | None = Field(title="Prompt")
+ negative_prompt: str | None = Field(title="Negative Prompt")
+
+class ArtistItem(BaseModel):
+ name: str = Field(title="Name")
+ score: float = Field(title="Score")
+ category: str = Field(title="Category") \ No newline at end of file
diff --git a/test/utils_test.py b/test/utils_test.py
new file mode 100644
index 00000000..65d3d177
--- /dev/null
+++ b/test/utils_test.py
@@ -0,0 +1,63 @@
+import unittest
+import requests
+
+class UtilsTests(unittest.TestCase):
+ def setUp(self):
+ self.url_options = "http://localhost:7860/sdapi/v1/options"
+ self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
+ self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
+ self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
+ self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
+ self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
+ self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
+ self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
+ self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
+ self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
+ self.url_artists = "http://localhost:7860/sdapi/v1/artists"
+
+ def test_options_get(self):
+ self.assertEqual(requests.get(self.url_options).status_code, 200)
+
+ def test_options_write(self):
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+
+ pre_value = response.json()["send_seed"]
+
+ self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
+
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.json()["send_seed"], not pre_value)
+
+ requests.post(self.url_options, json={"send_seed": pre_value})
+
+ def test_cmd_flags(self):
+ self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
+
+ def test_samplers(self):
+ self.assertEqual(requests.get(self.url_samplers).status_code, 200)
+
+ def test_upscalers(self):
+ self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
+
+ def test_sd_models(self):
+ self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
+
+ def test_hypernetworks(self):
+ self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
+
+ def test_face_restorers(self):
+ self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
+
+ def test_realesrgan_models(self):
+ self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
+
+ def test_prompt_styles(self):
+ self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
+
+ def test_artist_categories(self):
+ self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
+
+ def test_artists(self):
+ self.assertEqual(requests.get(self.url_artists).status_code, 200) \ No newline at end of file