aboutsummaryrefslogtreecommitdiff
path: root/modules/api/api.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-10-15 09:41:02 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-10-15 09:41:02 +0300
commit282903bb6798f49af66f6935ee4dc0015895cf7c (patch)
treec6bb9c1d4f298fe7c3b8d797cd7491fa873badf0 /modules/api/api.py
parent0d65d0eabded4a79c3168ed14599db3970d414c5 (diff)
repair unload sd checkpoint button
Diffstat (limited to 'modules/api/api.py')
-rw-r--r--modules/api/api.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index efedafa4..09083874 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
+from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -541,12 +540,12 @@ class Api:
return {}
def unloadapi(self):
- unload_model_weights()
+ sd_models.unload_model_weights()
return {}
def reloadapi(self):
- reload_model_weights()
+ sd_models.send_model_to_device(shared.sd_model)
return {}
@@ -566,7 +565,7 @@ class Api:
def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
- if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+ if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():