aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py102
1 files changed, 58 insertions, 44 deletions
diff --git a/webui.py b/webui.py
index e8f0a63d..cebfba96 100644
--- a/webui.py
+++ b/webui.py
@@ -8,7 +8,7 @@ import warnings
import json
from threading import Thread
-from fastapi import FastAPI
+from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
@@ -16,12 +16,12 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors
+from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer = timer.Timer()
import torch
-import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
+import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
@@ -31,19 +31,19 @@ startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
-import ldm.modules.encoders.modules
+import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -144,16 +144,11 @@ Use --skip-version-check commandline argument to disable this check.
""".strip())
-def initialize():
- fix_asyncio_event_loop_policy()
-
- check_versions()
-
- extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
- startup_timer.record("list extensions")
-
+def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
@@ -166,6 +161,18 @@ def initialize():
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
+
+def initialize():
+ fix_asyncio_event_loop_policy()
+
+ check_versions()
+
+ extensions.list_extensions()
+ localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list extensions")
+
+ restore_config_state_file()
+
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -185,7 +192,7 @@ def initialize():
startup_timer.record("load scripts")
modelloader.load_upscalers()
- startup_timer.record("load upscalers") #Is this necessary? I don't know.
+ startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
@@ -234,7 +241,10 @@ def initialize():
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
- signal.signal(signal.SIGINT, sigint_handler)
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
def setup_middleware(app):
@@ -255,19 +265,6 @@ def create_api(app):
return api
-def wait_on_server(demo=None):
- while 1:
- time.sleep(0.5)
- if shared.state.need_restart:
- shared.state.need_restart = False
- time.sleep(0.5)
- demo.close()
- time.sleep(0.5)
-
- modules.script_callbacks.app_reload_callback()
- break
-
-
def api_only():
initialize()
@@ -280,6 +277,12 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
+
+def stop_route(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
+
+
def webui():
launch_api = cmd_opts.api
initialize()
@@ -328,6 +331,9 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
+ if cmd_opts.add_stop_route:
+ app.add_route("/_stop", stop_route, methods=["POST"])
+
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -357,10 +363,29 @@ def webui():
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
- mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
+ gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
- wait_on_server(shared.demo)
+ try:
+ while True:
+ server_command = shared.state.wait_for_server_command(timeout=5)
+ if server_command:
+ if server_command in ("stop", "restart"):
+ break
+ else:
+ print(f"Unknown server command: {server_command}")
+ except KeyboardInterrupt:
+ print('Caught KeyboardInterrupt, stopping...')
+ server_command = "stop"
+
+ if server_command == "stop":
+ print("Stopping server...")
+ # If we catch a keyboard interrupt, we want to stop the server and exit.
+ shared.demo.close()
+ break
print('Restarting UI...')
+ shared.demo.close()
+ time.sleep(0.5)
+ modules.script_callbacks.app_reload_callback()
startup_timer.reset()
@@ -370,18 +395,7 @@ def webui():
extensions.list_extensions()
startup_timer.record("list extensions")
- config_state_file = shared.opts.restore_config_state_file
- shared.opts.restore_config_state_file = ""
- shared.opts.save(shared.config_filename)
-
- if os.path.isfile(config_state_file):
- print(f"*** About to restore extension state from file: {config_state_file}")
- with open(config_state_file, "r", encoding="utf-8") as f:
- config_state = json.load(f)
- config_states.restore_extension_config(config_state)
- startup_timer.record("restore extension config")
- elif config_state_file:
- print(f"!!! Config state backup not found: {config_state_file}")
+ restore_config_state_file()
localization.list_localizations(cmd_opts.localizations_dir)