aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-05-14 11:46:27 +0300
committerGitHub <noreply@github.com>2023-05-14 11:46:27 +0300
commit80adb6979d46bbb832254004cac4f4f9bec9efb3 (patch)
treee206ee60f9be21e9e20d483213b7d0a610d2bdbd /webui.py
parent1dcd6723242c3d691610f9ed937951baea49c2d1 (diff)
parent3ddc76342298ad0b2d14cb571ceb48c0b0c4176d (diff)
Merge branch 'dev' into find_vae
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py50
1 files changed, 29 insertions, 21 deletions
diff --git a/webui.py b/webui.py
index 357bf4c1..293a16cc 100644
--- a/webui.py
+++ b/webui.py
@@ -6,6 +6,8 @@ import signal
import re
import warnings
import json
+from threading import Thread
+
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -14,12 +16,12 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors
+from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer = timer.Timer()
import torch
-import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
+import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
@@ -29,19 +31,19 @@ startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
-import ldm.modules.encoders.modules
+import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -179,30 +181,22 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
- modelloader.list_builtin_upscalers()
- startup_timer.record("list builtin upscalers")
-
modules.scripts.load_scripts()
startup_timer.record("load scripts")
+ modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
+
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
- try:
- modules.sd_models.load_model()
- except Exception as e:
- errors.display(e, "loading stable diffusion model")
- print("", file=sys.stderr)
- print("Stable diffusion model failed to load, exiting", file=sys.stderr)
- exit(1)
- startup_timer.record("load SD checkpoint")
-
- shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+ # load model in parallel to other startup stuff
+ Thread(target=lambda: shared.sd_model).start()
- shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
@@ -286,7 +280,6 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-
def webui():
launch_api = cmd_opts.api
initialize()
@@ -313,6 +306,16 @@ def webui():
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ # this restores the missing /docs endpoint
+ if launch_api and not hasattr(FastAPI, 'original_setup'):
+ def fastapi_setup(self):
+ self.docs_url = "/docs"
+ self.redoc_url = "/redoc"
+ self.original_setup()
+
+ FastAPI.original_setup = FastAPI.setup
+ FastAPI.setup = fastapi_setup
+
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
@@ -339,6 +342,7 @@ def webui():
setup_middleware(app)
modules.progress.setup_progress_api(app)
+ modules.ui.setup_ui_api(app)
if launch_api:
create_api(app)
@@ -350,6 +354,11 @@ def webui():
print(f"Startup time: {startup_timer.summary()}.")
+ if cmd_opts.subpath:
+ redirector = FastAPI()
+ redirector.get("/")
+ gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
+
wait_on_server(shared.demo)
print('Restarting UI...')
@@ -376,7 +385,6 @@ def webui():
localization.list_localizations(cmd_opts.localizations_dir)
- modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
startup_timer.record("load scripts")