aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorw-e-w <40751091+w-e-w@users.noreply.github.com>2023-08-08 11:39:34 +0900
committerGitHub <noreply@github.com>2023-08-08 11:39:34 +0900
commitf17c8c2eff63210f5e96e1e2b049b46ba9cfa389 (patch)
tree701056aec9ae11bc45df9b39b176a54fa4d34e19 /webui.py
parentc75bda867be5345bf959daf23bdc19eadc90841a (diff)
parent01997f45ba089af24b03a5f614147bb0f9d8d824 (diff)
Merge branch 'dev' into auro-autolaunch
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py46
1 files changed, 10 insertions, 36 deletions
diff --git a/webui.py b/webui.py
index 844e2548..6d36f880 100644
--- a/webui.py
+++ b/webui.py
@@ -14,7 +14,6 @@ from typing import Iterable
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
import logging
@@ -50,6 +49,7 @@ startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
+
from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
@@ -58,10 +58,15 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared
+
+if not shared.cmd_opts.skip_version_check:
+ errors.check_versions()
+
import modules.codeformer_model as codeformer
-import modules.face_restoration
import modules.gfpgan_model as gfpgan
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
import modules.img2img
import modules.lowvram
@@ -130,37 +135,6 @@ def fix_asyncio_event_loop_policy():
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
-def check_versions():
- if shared.cmd_opts.skip_version_check:
- return
-
- expected_torch_version = "2.0.0"
-
- if version.parse(torch.__version__) < version.parse(expected_torch_version):
- errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, as well as
-there are reports of issues with training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
- """.strip())
-
- expected_xformers_version = "0.0.20"
- if shared.xformers_available:
- import xformers
-
- if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
- errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
- """.strip())
-
-
def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
if config_state_file == "":
@@ -237,7 +211,7 @@ def configure_sigint_handler():
def configure_opts_onchange():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
@@ -248,7 +222,6 @@ def initialize():
fix_asyncio_event_loop_policy()
validate_tls_options()
configure_sigint_handler()
- check_versions()
modelloader.cleanup_models()
configure_opts_onchange()
@@ -368,6 +341,7 @@ def api_only():
setup_middleware(app)
api = create_api(app)
+ modules.script_callbacks.before_ui_callback()
modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")