aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/gradio_extensons.py4
-rw-r--r--modules/initialize.py168
-rw-r--r--modules/initialize_util.py195
-rw-r--r--modules/shared_init.py3
-rw-r--r--modules/ui_extra_networks.py3
-rw-r--r--modules/ui_tempdir.py5
6 files changed, 371 insertions, 7 deletions
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
index 5af7fd8e..77c34c8b 100644
--- a/modules/gradio_extensons.py
+++ b/modules/gradio_extensons.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import scripts
+from modules import scripts, ui_tempdir
def add_classes_to_gradio_component(comp):
"""
@@ -58,3 +58,5 @@ original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.components.IOComponent.__init__ = IOComponent_init
gr.blocks.Block.get_config = Block_get_config
gr.blocks.BlockContext.__init__ = BlockContext_init
+
+ui_tempdir.install_ui_tempdir_override()
diff --git a/modules/initialize.py b/modules/initialize.py
new file mode 100644
index 00000000..f24f7637
--- /dev/null
+++ b/modules/initialize.py
@@ -0,0 +1,168 @@
+import importlib
+import logging
+import sys
+import warnings
+from threading import Thread
+
+from modules.timer import startup_timer
+
+
+def imports():
+ logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
+ import torch # noqa: F401
+ startup_timer.record("import torch")
+ import pytorch_lightning # noqa: F401
+ startup_timer.record("import torch")
+ warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+
+ import gradio # noqa: F401
+ startup_timer.record("import gradio")
+
+ from modules import paths, timer, import_hook, errors # noqa: F401
+ startup_timer.record("setup paths")
+
+ import ldm.modules.encoders.modules # noqa: F401
+ startup_timer.record("import ldm")
+
+ import sgm.modules.encoders.modules # noqa: F401
+ startup_timer.record("import sgm")
+
+ from modules import shared_init
+ shared_init.initialize()
+ startup_timer.record("initialize shared")
+
+ from modules import processing, gradio_extensons, ui # noqa: F401
+ startup_timer.record("other imports")
+
+
+def check_versions():
+ from modules.shared_cmd_options import cmd_opts
+
+ if not cmd_opts.skip_version_check:
+ from modules import errors
+ errors.check_versions()
+
+
+def initialize():
+ from modules import initialize_util
+ initialize_util.fix_torch_version()
+ initialize_util.fix_asyncio_event_loop_policy()
+ initialize_util.validate_tls_options()
+ initialize_util.configure_sigint_handler()
+ initialize_util.configure_opts_onchange()
+
+ from modules import modelloader
+ modelloader.cleanup_models()
+
+ from modules import sd_models
+ sd_models.setup_model()
+ startup_timer.record("setup SD model")
+
+ from modules.shared_cmd_options import cmd_opts
+
+ from modules import codeformer_model
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
+ codeformer_model.setup_model(cmd_opts.codeformer_models_path)
+ startup_timer.record("setup codeformer")
+
+ from modules import gfpgan_model
+ gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
+ startup_timer.record("setup gfpgan")
+
+ initialize_rest(reload_script_modules=False)
+
+
+def initialize_rest(*, reload_script_modules=False):
+ """
+ Called both from initialize() and when reloading the webui.
+ """
+ from modules.shared_cmd_options import cmd_opts
+
+ from modules import sd_samplers
+ sd_samplers.set_samplers()
+ startup_timer.record("set samplers")
+
+ from modules import extensions
+ extensions.list_extensions()
+ startup_timer.record("list extensions")
+
+ from modules import initialize_util
+ initialize_util.restore_config_state_file()
+ startup_timer.record("restore config state file")
+
+ from modules import shared, upscaler, scripts
+ if cmd_opts.ui_debug_mode:
+ shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
+ scripts.load_scripts()
+ return
+
+ from modules import sd_models
+ sd_models.list_models()
+ startup_timer.record("list SD models")
+
+ from modules import localization
+ localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list localizations")
+
+ with startup_timer.subcategory("load scripts"):
+ scripts.load_scripts()
+
+ if reload_script_modules:
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+ startup_timer.record("reload script modules")
+
+ from modules import modelloader
+ modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
+
+ from modules import sd_vae
+ sd_vae.refresh_vae_list()
+ startup_timer.record("refresh VAE")
+
+ from modules import textual_inversion
+ textual_inversion.textual_inversion.list_textual_inversion_templates()
+ startup_timer.record("refresh textual inversion templates")
+
+ from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
+ script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
+ sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
+
+ from modules import sd_unet
+ sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
+
+ def load_model():
+ """
+ Accesses shared.sd_model property to load model.
+ After it's available, if it has been loaded before this access by some extension,
+ its optimization may be None because the list of optimizaers has neet been filled
+ by that time, so we apply optimization again.
+ """
+
+ shared.sd_model # noqa: B018
+
+ if sd_hijack.current_optimizer is None:
+ sd_hijack.apply_optimizations()
+
+ from modules import devices
+ devices.first_time_calculation()
+
+ Thread(target=load_model).start()
+
+ from modules import shared_items
+ shared_items.reload_hypernetworks()
+ startup_timer.record("reload hypernetworks")
+
+ from modules import ui_extra_networks
+ ui_extra_networks.initialize()
+ ui_extra_networks.register_default_pages()
+
+ from modules import extra_networks
+ extra_networks.initialize()
+ extra_networks.register_default_extra_networks()
+ startup_timer.record("initialize extra networks")
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
new file mode 100644
index 00000000..e59bd3c4
--- /dev/null
+++ b/modules/initialize_util.py
@@ -0,0 +1,195 @@
+import json
+import logging
+import os
+import signal
+import sys
+import re
+
+from modules.timer import startup_timer
+
+def setup_logging():
+ # We can't use cmd_opts for this because it will not have been initialized at this point.
+ log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
+ if log_level:
+ log_level = getattr(logging, log_level.upper(), None) or logging.INFO
+ logging.basicConfig(
+ level=log_level,
+ format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ )
+
+
+def gradio_server_name():
+ from modules.shared_cmd_options import cmd_opts
+
+ if cmd_opts.server_name:
+ return cmd_opts.server_name
+ else:
+ return "0.0.0.0" if cmd_opts.listen else None
+
+
+def fix_torch_version():
+ import torch
+
+ # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+ if ".dev" in torch.__version__ or "+git" in torch.__version__:
+ torch.__long_version__ = torch.__version__
+ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
+
+
+def fix_asyncio_event_loop_policy():
+ """
+ The default `asyncio` event loop policy only automatically creates
+ event loops in the main threads. Other threads must create event
+ loops explicitly or `asyncio.get_event_loop` (and therefore
+ `.IOLoop.current`) will fail. Installing this policy allows event
+ loops to be created automatically on any thread, matching the
+ behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
+ """
+
+ import asyncio
+
+ if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
+ # "Any thread" and "selector" should be orthogonal, but there's not a clean
+ # interface for composing policies so pick the right base.
+ _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
+ else:
+ _BasePolicy = asyncio.DefaultEventLoopPolicy
+
+ class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
+ """Event loop policy that allows loop creation on any thread.
+ Usage::
+
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
+ """
+
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
+ try:
+ return super().get_event_loop()
+ except (RuntimeError, AssertionError):
+ # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
+ # and changed to a RuntimeError in 3.4.3.
+ # "There is no current event loop in thread %r"
+ loop = self.new_event_loop()
+ self.set_event_loop(loop)
+ return loop
+
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
+
+
+def restore_config_state_file():
+ from modules import shared, config_states
+
+ config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
+ shared.opts.restore_config_state_file = ""
+ shared.opts.save(shared.config_filename)
+
+ if os.path.isfile(config_state_file):
+ print(f"*** About to restore extension state from file: {config_state_file}")
+ with open(config_state_file, "r", encoding="utf-8") as f:
+ config_state = json.load(f)
+ config_states.restore_extension_config(config_state)
+ startup_timer.record("restore extension config")
+ elif config_state_file:
+ print(f"!!! Config state backup not found: {config_state_file}")
+
+
+def validate_tls_options():
+ from modules.shared_cmd_options import cmd_opts
+
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
+ return
+
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
+ startup_timer.record("TLS")
+
+
+def get_gradio_auth_creds():
+ """
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
+ an iterable of (username, password) tuples.
+ """
+ from modules.shared_cmd_options import cmd_opts
+
+ def process_credential_line(s):
+ s = s.strip()
+ if not s:
+ return None
+ return tuple(s.split(':', 1))
+
+ if cmd_opts.gradio_auth:
+ for cred in cmd_opts.gradio_auth.split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+ if cmd_opts.gradio_auth_path:
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ for cred in line.strip().split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+
+def configure_sigint_handler():
+ # make the program just exit at ctrl+c without waiting for anything
+ def sigint_handler(sig, frame):
+ print(f'Interrupted with signal {sig} in {frame}')
+ os._exit(0)
+
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
+
+
+def configure_opts_onchange():
+ from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
+ from modules.call_queue import wrap_queued_call
+
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ startup_timer.record("opts onchange")
+
+
+def setup_middleware(app):
+ from starlette.middleware.gzip import GZipMiddleware
+
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ configure_cors_middleware(app)
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
+
+
+def configure_cors_middleware(app):
+ from starlette.middleware.cors import CORSMiddleware
+ from modules.shared_cmd_options import cmd_opts
+
+ cors_options = {
+ "allow_methods": ["*"],
+ "allow_headers": ["*"],
+ "allow_credentials": True,
+ }
+ if cmd_opts.cors_allow_origins:
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
+ if cmd_opts.cors_allow_origins_regex:
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
+ app.add_middleware(CORSMiddleware, **cors_options)
+
diff --git a/modules/shared_init.py b/modules/shared_init.py
index b88d1d8e..d3fb687e 100644
--- a/modules/shared_init.py
+++ b/modules/shared_init.py
@@ -5,9 +5,6 @@ import torch
from modules import shared
from modules.shared import cmd_opts
-import sys
-sys.setrecursionlimit(1000)
-
def initialize():
"""Initializes fields inside the shared module in a controlled manner.
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index e0b932b9..16d76a45 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -4,7 +4,6 @@ from pathlib import Path
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo
-from modules.ui import up_down_symbol
import gradio as gr
import json
import html
@@ -348,6 +347,8 @@ def pages_in_preferred_order(pages):
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
+ from modules.ui import up_down_symbol
+
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index fb75137e..506017e5 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
return file_obj.name
-# override save to file function so that it also writes PNG info
-gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
+def install_ui_tempdir_override():
+ """override save to file function so that it also writes PNG info"""
+ gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed():