aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py128
1 files changed, 106 insertions, 22 deletions
diff --git a/webui.py b/webui.py
index dda34249..ddccf870 100644
--- a/webui.py
+++ b/webui.py
@@ -1,19 +1,29 @@
import os
-import threading
+import sys
import time
import importlib
import signal
-import threading
+import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
+from packaging import version
+import logging
+logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
+from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
+from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
-from modules.paths import script_path
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
+import torch
+
+# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+if ".dev" in torch.__version__ or "+git" in torch.__version__:
+ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
+
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
-import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
@@ -26,6 +36,8 @@ import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
+import modules.textual_inversion.textual_inversion
+import modules.progress
import modules.ui
from modules import modelloader
@@ -39,7 +51,40 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None
+def check_versions():
+ if shared.cmd_opts.skip_version_check:
+ return
+
+ expected_torch_version = "1.13.1"
+
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
+ errors.print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ expected_xformers_version = "0.0.16rc425"
+ if shared.xformers_available:
+ import xformers
+
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+ errors.print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+
def initialize():
+ check_versions()
+
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
@@ -54,19 +99,39 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+ modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
-
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
- modules.sd_models.load_model()
+
+ modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
+
+ try:
+ modules.sd_models.load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+ exit(1)
+
+ shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
- shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.reload_hypernetworks()
+
+ ui_extra_networks.intialize()
+ ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+ ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+ ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+
+ extra_networks.initialize()
+ extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
try:
@@ -90,11 +155,11 @@ def initialize():
def setup_cors(app):
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+ app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+ app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+ app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
def create_api(app):
@@ -137,9 +202,14 @@ def webui():
if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
+ modules.script_callbacks.before_ui_callback()
+
shared.demo = modules.ui.create_ui()
- app, local_url, share_url = shared.demo.queue(default_enabled=False).launch(
+ if cmd_opts.gradio_queue:
+ shared.demo.queue(64)
+
+ app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
@@ -155,38 +225,52 @@ def webui():
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
- # running web ui and do whatever the attcker wants, including installing an extension and
- # runnnig its code. We disable this here. Suggested by RyotaK.
+ # running web ui and do whatever the attacker wants, including installing an extension and
+ # running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)
app.add_middleware(GZipMiddleware, minimum_size=1000)
+ modules.progress.setup_progress_api(app)
+
if launch_api:
create_api(app)
- modules.script_callbacks.app_started_callback(shared.demo, app)
+ ui_extra_networks.add_pages_to_demo(app)
+
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
+ print('Restarting UI...')
sd_samplers.set_samplers()
- print('Reloading extensions')
+ modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
- print('Reloading custom scripts')
+ modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
+ modules.script_callbacks.model_loaded_callback(shared.sd_model)
modelloader.load_upscalers()
- print('Reloading modules: modules.ui')
- importlib.reload(modules.ui)
- print('Refreshing Model List')
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+
modules.sd_models.list_models()
- print('Restarting Gradio')
+
+ shared.reload_hypernetworks()
+
+ ui_extra_networks.intialize()
+ ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+ ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+ ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+
+ extra_networks.initialize()
+ extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if __name__ == "__main__":