Diffstat:
-rw-r--r--  launch.py                                8
-rw-r--r--  modules/cmd_args.py                      2
-rw-r--r--  modules/launch_utils.py                 31
-rw-r--r--  modules/sd_disable_initialization.py   106
-rw-r--r--  modules/sd_models.py                    16
-rw-r--r--  modules/timer.py                        23
-rw-r--r--  webui.py                                 4

7 files changed, 165 insertions, 25 deletions
diff --git a/launch.py b/launch.py
index 1dbc4c6e..e4c2ce99 100644
--- a/launch.py
+++ b/launch.py
@@ -1,6 +1,5 @@
from modules import launch_utils
-
args = launch_utils.args
python = launch_utils.python
git = launch_utils.git
@@ -26,8 +25,11 @@ start = launch_utils.start
def main():
- if not args.skip_prepare_environment:
- prepare_environment()
+ launch_utils.startup_timer.record("initial startup")
+
+ with launch_utils.startup_timer.subcategory("prepare environment"):
+ if not args.skip_prepare_environment:
+ prepare_environment()
if args.test_server:
configure_for_tests()
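
The hunk above moves environment preparation under a named timer subcategory. A minimal standalone sketch of how the subcategory API from modules/timer.py nests its records (hypothetical example, not part of the patch):

    from modules.timer import Timer

    timer = Timer(print_log=True)        # print_log mirrors --log-startup
    timer.record("initial startup")

    with timer.subcategory("prepare environment"):
        timer.record("checks")           # stored as "prepare environment/checks"
        timer.record("install torch")    # stored as "prepare environment/install torch"

    print(timer.records)                 # dict of per-category elapsed seconds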
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index e401f641..cb4ec5f7 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -13,6 +13,7 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
@@ -66,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index e1c9cfbe..f77b577a 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -10,9 +10,7 @@ from functools import lru_cache
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
-from modules import timer
-
-timer.startup_timer.record("start")
+from modules.timer import startup_timer
args, _ = cmd_args.parser.parse_known_args()
@@ -226,8 +224,13 @@ def run_extensions_installers(settings_file):
if not os.path.isdir(extensions_dir):
return
- for dirname_extension in list_extensions(settings_file):
- run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+ with startup_timer.subcategory("run extensions installers"):
+ for dirname_extension in list_extensions(settings_file):
+ path = os.path.join(extensions_dir, dirname_extension)
+
+ if os.path.isdir(path):
+ run_extension_installer(path)
+ startup_timer.record(dirname_extension)
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
@@ -300,8 +303,11 @@ def prepare_environment():
if not args.skip_python_version_check:
check_python_version()
+ startup_timer.record("checks")
+
commit = commit_hash()
tag = git_tag()
+ startup_timer.record("git version info")
print(f"Python {sys.version}")
print(f"Version: {tag}")
@@ -309,21 +315,27 @@ def prepare_environment():
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
+ startup_timer.record("install torch")
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
raise RuntimeError(
'Torch is not able to use GPU; '
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
)
+ startup_timer.record("torch GPU test")
+
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
+ startup_timer.record("install gfpgan")
if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
+ startup_timer.record("install clip")
if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")
+ startup_timer.record("install open_clip")
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows":
@@ -337,8 +349,11 @@ def prepare_environment():
elif platform.system() == "Linux":
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+ startup_timer.record("install xformers")
+
if not is_installed("ngrok") and args.ngrok:
run_pip("install ngrok", "ngrok")
+ startup_timer.record("install ngrok")
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
@@ -348,22 +363,28 @@ def prepare_environment():
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
+ startup_timer.record("clone repositores")
+
if not is_installed("lpips"):
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
+ startup_timer.record("install CodeFormer requirements")
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
if not requirements_met(requirements_file):
run_pip(f"install -r \"{requirements_file}\"", "requirements")
+ startup_timer.record("install requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
if args.update_check:
version_check(commit)
+ startup_timer.record("check version")
if args.update_all_extensions:
git_pull_recursive(extensions_dir)
+ startup_timer.record("update extensions")
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 9fc89dc6..695c5736 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -3,8 +3,31 @@ import open_clip
import torch
import transformers.utils.hub
+from modules import shared
-class DisableInitialization:
+
+class ReplaceHelper:
+ def __init__(self):
+ self.replaced = []
+
+ def replace(self, obj, field, func):
+ original = getattr(obj, field, None)
+ if original is None:
+ return None
+
+ self.replaced.append((obj, field, original))
+ setattr(obj, field, func)
+
+ return original
+
+ def restore(self):
+ for obj, field, original in self.replaced:
+ setattr(obj, field, original)
+
+ self.replaced.clear()
+
+
+class DisableInitialization(ReplaceHelper):
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
@@ -21,7 +44,7 @@ class DisableInitialization:
"""
def __init__(self, disable_clip=True):
- self.replaced = []
+ super().__init__()
self.disable_clip = disable_clip
def replace(self, obj, field, func):
@@ -86,8 +109,81 @@ class DisableInitialization:
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
- for obj, field, original in self.replaced:
- setattr(obj, field, original)
+ self.restore()
- self.replaced.clear()
+class InitializeOnMeta(ReplaceHelper):
+ """
+ Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on the meta device,
+ which results in those parameters having no values and taking no memory. model.to() is broken by this and
+ must be repaired with LoadStateDictOnMeta below when loading params from the state dict.
+
+ Usage:
+ ```
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+ ```
+ """
+
+ def __enter__(self):
+ if shared.cmd_opts.disable_model_loading_ram_optimization:
+ return
+
+ def set_device(x):
+ x["device"] = "meta"
+ return x
+
+ linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
+ conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
+ mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
+ self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.restore()
+
+
+class LoadStateDictOnMeta(ReplaceHelper):
+ """
+ Context manager that allows reading parameters from a state_dict into a model that has some of its parameters on the meta device.
+ As parameters are read from the state_dict, they are deleted from it, so by the end the state_dict is mostly empty, saving memory.
+ Meant to be used together with InitializeOnMeta above.
+
+ Usage:
+ ```
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
+ model.load_state_dict(state_dict, strict=False)
+ ```
+ """
+
+ def __init__(self, state_dict, device):
+ super().__init__()
+ self.state_dict = state_dict
+ self.device = device
+
+ def __enter__(self):
+ if shared.cmd_opts.disable_model_loading_ram_optimization:
+ return
+
+ sd = self.state_dict
+ device = self.device
+
+ def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+ params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+
+ for name, param in params:
+ if param.is_meta:
+ self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+
+ original(self, state_dict, prefix, *args, **kwargs)
+
+ for name, _ in params:
+ key = prefix + name
+ if key in sd:
+ del sd[key]
+
+ linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
+ conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
+ mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.restore()
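
The optimization rests on PyTorch's meta device: a meta tensor carries shape and dtype but no storage. A self-contained illustration (assumes a PyTorch recent enough for factory device kwargs, 1.9+):

    import torch

    # Allocating on "meta" records metadata only; no RAM is spent on values.
    layer = torch.nn.Linear(4096, 4096, device="meta")
    assert layer.weight.is_meta

    # Meta parameters hold no data and cannot be moved with .to(). Real values
    # arrive later: LoadStateDictOnMeta swaps each meta parameter for a zeros
    # tensor on the target device just before the original _load_from_state_dict
    # runs, then deletes the consumed keys from the state_dict to free memory.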
diff --git a/modules/sd_models.py b/modules/sd_models.py
index fb31a793..acb1e817 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -460,7 +460,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
-
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
- sd_model = instantiate_from_config(sd_config.model)
- except Exception:
- pass
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+
+ except Exception as e:
+ errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
- sd_model = instantiate_from_config(sd_config.model)
+
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
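
Condensed, the new load path composes the two context managers like this (names as in the surrounding load_model() function):

    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)    # params land on meta

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)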
diff --git a/modules/timer.py b/modules/timer.py
index da99e49f..1d38595c 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -1,4 +1,5 @@
import time
+import argparse
class TimerSubcategory:
@@ -11,20 +12,27 @@ class TimerSubcategory:
def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"
+ self.timer.subcategory_level += 1
+
+ if self.timer.print_log:
+ print(f"{' ' * self.timer.subcategory_level}{self.category}:")
def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategory = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategory)
- self.timer.record(self.category)
+ self.timer.subcategory_level -= 1
+ self.timer.record(self.category, disable_log=True)
class Timer:
- def __init__(self):
+ def __init__(self, print_log=False):
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ''
+ self.print_log = print_log
+ self.subcategory_level = 0
def elapsed(self):
end = time.time()
@@ -38,13 +46,16 @@ class Timer:
self.records[category] += amount
- def record(self, category, extra_time=0):
+ def record(self, category, extra_time=0, disable_log=False):
e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time
+ if self.print_log and not disable_log:
+ print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
+
def subcategory(self, name):
self.elapsed()
@@ -71,6 +82,10 @@ class Timer:
self.__init__()
-startup_timer = Timer()
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
+args = parser.parse_known_args()[0]
+
+startup_timer = Timer(print_log=args.log_startup)
startup_record = None
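
Rather than importing modules.cmd_args (presumably to keep modules.timer free of that dependency at import time), timer.py pre-parses the single flag it cares about with a throwaway parser. parse_known_args tolerates every other argument, which is what makes this safe; a minimal demonstration:

    import argparse

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--log-startup", action='store_true')

    # Unrecognized flags are returned rather than rejected, so this parser can
    # run before (and independently of) the full cmd_args parser.
    args, unknown = parser.parse_known_args(["--log-startup", "--port", "7860"])
    assert args.log_startup and unknown == ["--port", "7860"]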
diff --git a/webui.py b/webui.py
index 6bf06854..2dc4f1aa 100644
--- a/webui.py
+++ b/webui.py
@@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False):
if modules.sd_hijack.current_optimizer is None:
modules.sd_hijack.apply_optimizations()
- Thread(target=load_model).start()
+ devices.first_time_calculation()
- Thread(target=devices.first_time_calculation).start()
+ Thread(target=load_model).start()
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")