From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 11:28:12 +0300
Subject: remove the need to place configs near models

---
 modules/sd_models.py | 228 +++++++++++++++++++++++++--------------------------
 1 file changed, 113 insertions(+), 115 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7072eb2e..fa208728 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,8 +2,6 @@ import collections
 import os.path
 import sys
 import gc
-import time
-from collections import namedtuple
 import torch
 import re
 import safetensors.torch
@@ -14,10 +12,10 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
-from modules.sd_hijack_ip2p import should_hijack_ip2p
+from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules.timer import Timer
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -99,17 +97,6 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
 
 
-def find_checkpoint_config(info):
-    if info is None:
-        return shared.cmd_opts.config
-
-    config = os.path.splitext(info.filename)[0] + ".yaml"
-    if os.path.exists(config):
-        return config
-
-    return shared.cmd_opts.config
-
-
 def list_models():
     checkpoints_list.clear()
     checkpoint_alisases.clear()
@@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
 def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
-        device = map_location or shared.weight_load_location
-        if device is None:
-            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
         pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     return sd
 
 
-def load_model_weights(model, checkpoint_info: CheckpointInfo):
+def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
+    if checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        return checkpoints_loaded[checkpoint_info]
+
+    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    res = read_state_dict(checkpoint_info.filename)
+    timer.record("load weights from disk")
+
+    return res
+
+
+def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     title = checkpoint_info.title
     sd_model_hash = checkpoint_info.calculate_shorthash()
+    timer.record("calculate hash")
+
     if checkpoint_info.title != title:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
-    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+    if state_dict is None:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    if cache_enabled and checkpoint_info in checkpoints_loaded:
-        # use checkpoint cache
-        print(f"Loading weights [{sd_model_hash}] from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-    else:
-        # load from file
-        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
+    model.load_state_dict(state_dict, strict=False)
+    del state_dict
+    timer.record("apply weights to model")
 
-        sd = read_state_dict(checkpoint_info.filename)
-        model.load_state_dict(sd, strict=False)
-        del sd
-
-        if cache_enabled:
-            # cache newly loaded model
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+    if shared.opts.sd_checkpoint_cache > 0:
+        # cache newly loaded model
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+
+    if shared.cmd_opts.opt_channelslast:
+        model.to(memory_format=torch.channels_last)
+        timer.record("apply channels_last")
 
-        if shared.cmd_opts.opt_channelslast:
-            model.to(memory_format=torch.channels_last)
+    if not shared.cmd_opts.no_half:
+        vae = model.first_stage_model
+        depth_model = getattr(model, 'depth_model', None)
 
-        if not shared.cmd_opts.no_half:
-            vae = model.first_stage_model
-            depth_model = getattr(model, 'depth_model', None)
+        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+        if shared.cmd_opts.no_half_vae:
+            model.first_stage_model = None
+        # with --upcast-sampling, don't convert the depth model weights to float16
+        if shared.cmd_opts.upcast_sampling and depth_model:
+            model.depth_model = None
 
-            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-            if shared.cmd_opts.no_half_vae:
-                model.first_stage_model = None
-            # with --upcast-sampling, don't convert the depth model weights to float16
-            if shared.cmd_opts.upcast_sampling and depth_model:
-                model.depth_model = None
+        model.half()
+        model.first_stage_model = vae
+        if depth_model:
+            model.depth_model = depth_model
 
-            model.half()
-            model.first_stage_model = vae
-            if depth_model:
-                model.depth_model = depth_model
+        timer.record("apply half()")
 
-        devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
-        devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
-        devices.dtype_unet = model.model.diffusion_model.dtype
-        devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
+    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+    devices.dtype_unet = model.model.diffusion_model.dtype
+    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
-        model.first_stage_model.to(devices.dtype_vae)
+    model.first_stage_model.to(devices.dtype_vae)
+    timer.record("apply dtype to VAE")
 
     # clean up cache if limit is reached
-    if cache_enabled:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
-            checkpoints_loaded.popitem(last=False)  # LRU
+    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        checkpoints_loaded.popitem(last=False)
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_info.filename
@@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
     sd_vae.clear_loaded_vae()
     vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
     sd_vae.load_vae(model, vae_file, vae_source)
+    timer.record("load VAE")
 
 
 def enable_midas_autodownload():
@@ -340,24 +340,20 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper
 
 
-class Timer:
-    def __init__(self):
-        self.start = time.time()
+def repair_config(sd_config):
 
-    def elapsed(self):
-        end = time.time()
-        res = end - self.start
-        self.start = end
-        return res
+    if not hasattr(sd_config.model.params, "use_ema"):
+        sd_config.model.params.use_ema = False
 
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True
 
-def load_model(checkpoint_info=None):
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
-    checkpoint_config = find_checkpoint_config(checkpoint_info)
-
-    if checkpoint_config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_config}")
 
     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
     gc.collect()
     devices.torch_gc()
 
-    sd_config = OmegaConf.load(checkpoint_config)
-
-    if should_hijack_inpainting(checkpoint_info):
-        # Hardcoded config for now...
-        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.unet_config.params.in_channels = 9
-        sd_config.model.params.finetune_keys = None
-
-    if should_hijack_ip2p(checkpoint_info):
-        sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
-        sd_config.model.params.conditioning_key = "hybrid"
-        sd_config.model.params.first_stage_key = "edited"
-        sd_config.model.params.cond_stage_key = "edit"
-        sd_config.model.params.image_size = 16
-        sd_config.model.params.unet_config.params.in_channels = 8
-        sd_config.model.params.unet_config.params.out_channels = 4
+    do_inpainting_hijack()
 
-    if not hasattr(sd_config.model.params, "use_ema"):
-        sd_config.model.params.use_ema = False
+    timer = Timer()
 
-    do_inpainting_hijack()
+    if already_loaded_state_dict is not None:
+        state_dict = already_loaded_state_dict
+    else:
+        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
 
-    timer = Timer()
+    timer.record("find config")
 
-    sd_model = None
+    sd_config = OmegaConf.load(checkpoint_config)
+    repair_config(sd_config)
+
+    timer.record("load config")
+
+    print(f"Creating model from config: {checkpoint_config}")
+    sd_model = None
 
     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
@@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)
 
-    elapsed_create = timer.elapsed()
+    sd_model.used_config = checkpoint_config
 
-    load_model_weights(sd_model, checkpoint_info)
+    timer.record("create model")
 
-    elapsed_load_weights = timer.elapsed()
+    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
         sd_model.to(shared.device)
 
+    timer.record("move model to device")
+
     sd_hijack.model_hijack.hijack(sd_model)
 
+    timer.record("hijack")
+
     sd_model.eval()
     shared.sd_model = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
+    timer.record("load textual inversion embeddings")
+
     script_callbacks.model_loaded_callback(sd_model)
 
-    elapsed_the_rest = timer.elapsed()
+    timer.record("scripts callbacks")
 
-    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
+    print(f"Model loaded in {timer.summary()}.")
 
     return sd_model
 
@@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
 
     if not sd_model:
         sd_model = shared.sd_model
+
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
@@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return
 
-    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
-
-    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
-        del sd_model
-        checkpoints_loaded.clear()
-        load_model(checkpoint_info)
-        return shared.sd_model
-
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.send_everything_to_cpu()
     else:
@@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
 
     timer = Timer()
 
+    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+
+    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+
+    timer.record("find config")
+
+    if sd_model is None or checkpoint_config != sd_model.used_config:
+        del sd_model
+        checkpoints_loaded.clear()
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+        return shared.sd_model
+
     try:
-        load_model_weights(sd_model, checkpoint_info)
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     except Exception as e:
         print("Failed to load checkpoint, restoring previous")
-        load_model_weights(sd_model, current_checkpoint_info)
+        load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
     finally:
         sd_hijack.model_hijack.hijack(sd_model)
+        timer.record("hijack")
+
         script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")
 
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
+        timer.record("move model to device")
 
-    elapsed = timer.elapsed()
-
-    print(f"Weights loaded in {elapsed:.1f}s.")
+    print(f"Weights loaded in {timer.summary()}.")
 
     return sd_model
--
cgit v1.2.1
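The commit above drops the inline Timer class and instead imports one from modules.timer, calling timer.record(label) after each stage, reading timer.records["load weights from disk"], and printing timer.summary(). That module is not part of this diff; the following is a minimal sketch of a compatible helper, assuming only the record()/records/summary() interface visible in the calls above, and it may differ from the actual modules/timer.py:

import time


class Timer:
    """Accumulates elapsed time per named stage (sketch, not the actual modules/timer.py)."""

    def __init__(self):
        self.start = time.time()
        self.records = {}
        self.total = 0.0

    def elapsed(self):
        # time since the previous checkpoint; also resets the reference point
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def record(self, category):
        # attribute the elapsed time to a named stage, e.g. "load weights from disk"
        e = self.elapsed()
        self.records[category] = self.records.get(category, 0.0) + e
        self.total += e

    def summary(self):
        # e.g. "4.3s (load weights from disk: 2.1s, apply weights to model: 1.9s)"
        res = f"{self.total:.1f}s"
        significant = [(name, t) for name, t in self.records.items() if t >= 0.1]
        if significant:
            res += " (" + ", ".join(f"{name}: {t:.1f}s" for name, t in significant) + ")"
        return res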
From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 27 Jan 2023 11:54:19 +0300
Subject: support detecting midas model

fix broken api for checkpoint list
---
 modules/sd_models.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index fa208728..37dad18d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        sd_model.to(devices.cpu)
+        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+            lowvram.send_everything_to_cpu()
+        else:
+            sd_model.to(devices.cpu)
 
-    sd_hijack.model_hijack.undo_hijack(sd_model)
+        sd_hijack.model_hijack.undo_hijack(sd_model)
 
     timer = Timer()
 
--
cgit v1.2.1
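Both commits lean on sd_models_config.find_checkpoint_config(state_dict, checkpoint_info), which now receives the loaded state dict so the right .yaml can be inferred from the weights themselves (the second commit's subject extends this to MiDaS/depth checkpoints). That module is outside this diff; the sketch below only illustrates the general idea of state-dict-based detection, and the key names and config filenames are assumptions rather than the actual sd_models_config implementation:

def guess_config_from_state_dict(sd):
    """Illustrative sketch: pick a config .yaml by inspecting checkpoint tensors."""
    # depth2img checkpoints ship the MiDaS depth estimator's weights under "depth_model."
    if any(key.startswith("depth_model.") for key in sd):
        return "v2-midas-inference.yaml"  # assumed filename

    # the UNet's first conv encodes how many conditioning channels the model expects
    unet_in = sd.get("model.diffusion_model.input_blocks.0.0.weight")
    if unet_in is not None:
        if unet_in.shape[1] == 9:
            return "v1-inpainting-inference.yaml"  # 4 latent + 4 masked-image + 1 mask channels
        if unet_in.shape[1] == 8:
            return "instruct-pix2pix.yaml"  # 4 latent + 4 edited-image channels

    # SD 2.x uses OpenCLIP, so its text encoder keys differ from SD 1.x
    if "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight" in sd:
        return "v2-inference.yaml"

    return "v1-inference.yaml"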