aboutsummaryrefslogtreecommitdiff
path: root/modules/api/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api/api.py')
-rw-r--r--modules/api/api.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index fb2c2ce9..6e8d21a3 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -23,8 +23,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
-from modules.sd_vae import vae_dict
+from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -57,6 +56,15 @@ def setUpscalers(req: dict):
def decode_base64_to_image(encoding):
+ if encoding.startswith("http://") or encoding.startswith("https://"):
+ import requests
+ response = requests.get(encoding, timeout=30, headers={'user-agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.94 Safari/537.36'})
+ try:
+ image = Image.open(BytesIO(response.content))
+ return image
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Invalid image url") from e
+
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
@@ -567,10 +575,12 @@ class Api:
]
def get_sd_models(self):
- return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ import modules.sd_models as sd_models
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
def get_sd_vaes(self):
- return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+ import modules.sd_vae as sd_vae
+ return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]