aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py32
1 files changed, 20 insertions, 12 deletions
diff --git a/webui.py b/webui.py
index 3aee8792..d89e0fb5 100644
--- a/webui.py
+++ b/webui.py
@@ -1,4 +1,5 @@
import os
+import sys
import threading
import time
import importlib
@@ -8,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from modules import import_hook
+from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
@@ -55,12 +56,20 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+ modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
-
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
- modules.sd_models.load_model()
+
+ try:
+ modules.sd_models.load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+ exit(1)
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
@@ -91,11 +100,11 @@ def initialize():
def setup_cors(app):
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+ app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+ app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+ app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
def create_api(app):
@@ -169,23 +178,22 @@ def webui():
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
+ print('Restarting UI...')
sd_samplers.set_samplers()
- print('Reloading extensions')
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
- print('Reloading custom scripts')
+ modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
modelloader.load_upscalers()
- print('Reloading modules: modules.ui')
- importlib.reload(modules.ui)
- print('Refreshing Model List')
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+
modules.sd_models.list_models()
- print('Restarting Gradio')
if __name__ == "__main__":