aboutsummaryrefslogtreecommitdiff
path: root/modules/api/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api/api.py')
-rw-r--r--modules/api/api.py28
1 files changed, 8 insertions, 20 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 9814bbc2..5d60fc0a 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.extras import run_extras
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
@@ -45,10 +44,8 @@ def validate_sampler_name(name):
def setUpscalers(req: dict):
reqDict = vars(req)
- reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
- reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
- reqDict.pop('upscaler_1')
- reqDict.pop('upscaler_2')
+ reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
+ reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
def decode_base64_to_image(encoding):
@@ -126,8 +123,6 @@ class Api:
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
- self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
- self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
@@ -246,7 +241,7 @@ class Api:
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
- result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
@@ -262,7 +257,7 @@ class Api:
reqDict.pop('imageList')
with self.queue_lock:
- result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
@@ -390,12 +385,6 @@ class Api:
return styleList
- def get_artists_categories(self):
- return shared.artist_db.cats
-
- def get_artists(self):
- return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
-
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
@@ -480,7 +469,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -491,16 +480,15 @@ class Api:
except Exception as e:
error = e
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
- return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
except AssertionError as msg:
shared.state.end()
- return TrainResponse(info = "train embedding error: {error}".format(error = error))
+ return TrainResponse(info="train embedding error: {error}".format(error=error))
def get_memory(self):
try: