aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
authorarcticfaded <jbelt021@fiu.edu>2022-10-19 05:19:01 +0000
committerarcticfaded <jbelt021@fiu.edu>2022-10-19 05:19:01 +0000
commit0f0d6ab8e06898ce066251fc769fe14e77e98ced (patch)
treea8587d440fce92fc427128baee8aa645f63f687b /modules/api
parente7f4808505f7a6339927c32b9a0c01bc9134bdeb (diff)
call sampler by name
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py11
-rw-r--r--modules/api/processing.py6
2 files changed, 9 insertions, 8 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index ff9df0d1..5b0c934e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,6 +1,7 @@
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
-from modules.sd_samplers import samplers_k_diffusion
+from modules.sd_samplers import all_samplers
+from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
@@ -10,7 +11,7 @@ import json
import io
import base64
-sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)
+sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
@@ -53,13 +54,13 @@ class Api:
- def img2imgendoint(self):
+ def img2imgapi(self):
raise NotImplementedError
- def extrasendoint(self):
+ def extrasapi(self):
raise NotImplementedError
- def pnginfoendoint(self):
+ def pnginfoapi(self):
raise NotImplementedError
def launch(self, server_name, port):
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 2e6483ee..4c541241 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -1,7 +1,7 @@
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessingTxt2Img
import inspect
@@ -95,5 +95,5 @@ class PydanticModelGenerator:
StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
-).generate_model()
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model() \ No newline at end of file