From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, reorganization of the code --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..8e4ee435 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging + from transformers import logging, CLIPModel logging.set_verbosity_error() except Exception: @@ -196,6 +196,9 @@ def load_model(): sd_hijack.model_hijack.hijack(sd_model) + if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: + shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) + sd_model.eval() print(f"Model loaded.") -- cgit v1.2.1 From 8e7097d06a6a261580d34375c9d2a9e4ffc63ffa Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 13:47:45 -0700 Subject: Added support for RunwayML inpainting model --- modules/sd_models.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index eae22e87..47836d25 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices from modules.paths import models_path +from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -211,6 +212,19 @@ def load_model(): print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) + + if should_hijack_inpainting(checkpoint_info): + do_inpainting_hijack() + + # Hardcoded config for now... + sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + sd_config.model.params.use_ema = False + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.unet_config.params.in_channels = 9 + + # Create a "fake" config with a different name so that we know to unload it when switching models. 
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model -- cgit v1.2.1 From 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 13:28:43 -0700 Subject: Added PLMS hijack and made sure to always replace methods --- modules/sd_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 47836d25..7072db08 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -214,8 +214,6 @@ def load_model(): sd_config = OmegaConf.load(checkpoint_info.config) if should_hijack_inpainting(checkpoint_info): - do_inpainting_hijack() - # Hardcoded config for now... sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" sd_config.model.params.use_ema = False @@ -225,6 +223,7 @@ def load_model(): # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.1 From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 16:01:27 -0700 Subject: XY grid correctly re-assigns model when config changes --- modules/sd_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08..fea84630 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") @@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + shared.sd_model = load_model(checkpoint_info) return shared.sd_model -- cgit v1.2.1 From df5706409386cc2e88718bd9101045587c39f8bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:10:51 +0300 Subject: do not load aesthetic clip model until it's needed add refresh button for aesthetic embeddings add aesthetic params to images' infotext --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_models.py') diff --git 
a/modules/sd_models.py b/modules/sd_models.py index 05a1df28..b1c91b0d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -234,9 +234,6 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) - if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: - shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) - sd_model.eval() print(f"Model loaded.") -- cgit v1.2.1 From ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 17:35:51 +0300 Subject: loading SD VAE, see PR #3303 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b1c91b0d..d99dbce8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.1 From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 12:23:45 +0300 Subject: removed aesthetic gradients as built-in added support for extensions --- modules/sd_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index d99dbce8..f9b3063d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices +from modules import shared, modelloader, devices, script_callbacks from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -238,6 +238,9 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) sd_model.eval() + shared.sd_model = sd_model + + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") return sd_model @@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model(checkpoint_info) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: -- cgit v1.2.1 From 0df94d3fcf9d1fc47c4d39039352a3d5b3380c1f Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 22 Oct 2022 12:59:21 -0400 Subject: fix aesthetic gradients doing nothing after loading a different model --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f9b3063d..49dc3238 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -236,12 +236,11 @@ def load_model(checkpoint_info=None): sd_model.to(shared.device) sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) sd_model.eval() shared.sd_model = sd_model - script_callbacks.model_loaded_callback(sd_model) - print(f"Model loaded.") return sd_model @@ -268,6 +267,7 @@ def reload_model_weights(sd_model, info=None): load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) -- cgit v1.2.1 From 321bacc6a9eaf4a25f31279f288fa752be507a20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:15:12 +0300 Subject: call model_loaded_callback after setting shared.sd_model in case scripts refer to it using that --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 49dc3238..e697bb72 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -236,11 +236,12 @@ def load_model(checkpoint_info=None): sd_model.to(shared.device) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) sd_model.eval() shared.sd_model = sd_model + script_callbacks.model_loaded_callback(sd_model) + print(f"Model loaded.") return sd_model -- cgit v1.2.1 From b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9 Mon Sep 17 00:00:00 2001 From: Josh Watzman Date: Thu, 27 Oct 2022 21:59:16 +0100 Subject: Reduce peak memory usage when changing models A few tweaks to reduce peak memory usage, the biggest being that if we aren't using the checkpoint cache, we shouldn't duplicate the model state dict just to immediately throw it away. On my machine with 16GB of RAM, this change means I can typically change models, whereas before it would typically OOM. 
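The essence of the change below: extract the state dict, free the wrapper checkpoint dict before load_state_dict runs, and copy the weights into the cache only when caching is actually enabled. A minimal self-contained sketch of that idea follows; it is not part of the patch, and the helper name and simplified signature are hypothetical (the real logic lives in load_model_weights):

```
import collections

import torch

checkpoints_loaded = collections.OrderedDict()  # LRU cache of state dicts

def load_weights_low_peak(model, checkpoint_file, cache_size=0):
    # Load the checkpoint, pull out its state dict, and drop the wrapper
    # dict before load_state_dict runs, so two full copies of the weights
    # never coexist in RAM.
    pl_sd = torch.load(checkpoint_file, map_location="cpu")
    sd = pl_sd.get("state_dict", pl_sd)
    del pl_sd
    model.load_state_dict(sd, strict=False)
    del sd
    # Duplicate the weights into the cache only if caching was requested.
    if cache_size > 0:
        checkpoints_loaded[checkpoint_file] = model.state_dict().copy()
        while len(checkpoints_loaded) > cache_size:
            checkpoints_loaded.popitem(last=False)  # evict least recently used
```

With cache_size=0 (the default in this sketch), no extra copy of the weights is ever made, which is what keeps the peak memory low.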
--- modules/sd_models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e697bb72..203e99a8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info): print(f"Global Step: {pl_sd['global_step']}") sd = get_state_dict_from_checkpoint(pl_sd) - missing, extra = model.load_state_dict(sd, strict=False) + del pl_sd + model.load_state_dict(sd, strict=False) + del sd if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) @@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.to(devices.dtype_vae) - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: - checkpoints_loaded.popitem(last=False) # LRU + if shared.opts.sd_checkpoint_cache > 0: + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") checkpoints_loaded.move_to_end(checkpoint_info) -- cgit v1.2.1 From 5d5dc64064d8ca399a76fe44dbb62bdef6c4b7c4 Mon Sep 17 00:00:00 2001 From: Antonio Date: Fri, 28 Oct 2022 05:49:39 +0200 Subject: Natural sorting for dropdown checkpoint list
Example:
Before                      After
11.ckpt                     11.ckpt
ab.ckpt                     ab.ckpt
ade_pablo_step_1000.ckpt    ade_pablo_step_500.ckpt
ade_pablo_step_500.ckpt     ade_pablo_step_1000.ckpt
ade_step_1000.ckpt          ade_step_500.ckpt
ade_step_1500.ckpt          ade_step_1000.ckpt
ade_step_2000.ckpt          ade_step_1500.ckpt
ade_step_2500.ckpt          ade_step_2000.ckpt
ade_step_3000.ckpt          ade_step_2500.ckpt
ade_step_500.ckpt           ade_step_3000.ckpt
atp_step_5500.ckpt          atp_step_5500.ckpt
model1.ckpt                 model1.ckpt
model10.ckpt                model10.ckpt
model1000.ckpt              model33.ckpt
model33.ckpt                model50.ckpt
model400.ckpt               model400.ckpt
model50.ckpt                model1000.ckpt
moo44.ckpt                  moo44.ckpt
v1-4-pruned-emaonly.ckpt    v1-4-pruned-emaonly.ckpt
v1-5-pruned-emaonly.ckpt    v1-5-pruned-emaonly.ckpt
v1-5-pruned.ckpt            v1-5-pruned.ckpt
v1-5-vae.ckpt               v1-5-vae.ckpt
--- modules/sd_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e697bb72..64d5ee0d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -3,6 +3,7 @@ import os.path import sys from collections import namedtuple import torch +import re from omegaconf import OmegaConf from ldm.util import instantiate_from_config @@ -35,8 +36,10 @@ def setup_model(): list_models() -def checkpoint_tiles(): - return sorted([x.title for x in checkpoints_list.values()]) +def checkpoint_tiles(): + convert = lambda name: int(name) if name.isdigit() else name.lower() + alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] + return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) def list_models(): -- cgit v1.2.1 From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- modules/sd_models.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..91ad4b5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7
@@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks +from modules import shared, modelloader, devices, script_callbacks, sd_vae from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - -def load_model_weights(model, checkpoint_info): +def load_model_weights(model, checkpoint_info, force=False): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if checkpoint_info not in checkpoints_loaded: + if force or checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path - - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict) - + sd_vae.load_vae(model, checkpoint_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: @@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None): +def load_model(checkpoint_info=None, force=False): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -234,7 +223,7 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -252,16 +241,16 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model, info=None, force=False): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info) + load_model(checkpoint_info, force=force) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) -- cgit 
v1.2.1 From 726769da35970f4c100fa7edf11850f9dc059c41 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:19:34 +0700 Subject: Checkpoint cache by combination key of checkpoint and vae --- modules/sd_models.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 91ad4b5e..850f7b7b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -def load_model_weights(model, checkpoint_info, force=False): +def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if force or checkpoint_info not in checkpoints_loaded: + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + + checkpoint_key = (checkpoint_info, vae_file) + + if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, checkpoint_file) + sd_vae.load_vae(model, vae_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU else: - print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) - model.load_state_dict(checkpoints_loaded[checkpoint_info]) + vae_name = sd_vae.get_filename(vae_file) + print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + checkpoints_loaded.move_to_end(checkpoint_key) + model.load_state_dict(checkpoints_loaded[checkpoint_key]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None, force=False): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info, force=force) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, 
checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) -- cgit v1.2.1 From 36966e3200943dbf890b5338cfa939df552d3c47 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:38:58 +0700 Subject: Fix #4035 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..a29c8c1a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -201,7 +201,7 @@ def load_model_weights(model, checkpoint_info): if shared.opts.sd_checkpoint_cache > 0: checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") -- cgit v1.2.1 From bf7a699845675eefdabb9cfa40c55398976274ae Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 16:27:27 +0700 Subject: Fix #4035 for real now --- modules/sd_models.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a29c8c1a..b2dd005a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash + if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) - - if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: - checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) model.load_state_dict(checkpoints_loaded[checkpoint_info]) + if shared.opts.sd_checkpoint_cache > 0: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -- cgit v1.2.1 From af758e97fa2c4c853042f121af4e974be01e6696 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 1 Nov 2022 04:01:49 -0300 Subject: Unload sd_model before loading the other --- modules/sd_models.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..90007da3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,6 +1,7 @@ import collections import os.path import sys +import gc from collections import namedtuple import torch import re @@ -220,6 +221,12 @@ def load_model(checkpoint_info=None): if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") + if shared.sd_model: + 
sd_hijack.model_hijack.undo_hijack(shared.sd_model) + shared.sd_model = None + gc.collect() + devices.torch_gc() + sd_config = OmegaConf.load(checkpoint_info.config) if should_hijack_inpainting(checkpoint_info): @@ -233,6 +240,7 @@ def load_model(checkpoint_info=None): checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) do_inpainting_hijack() + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -252,14 +260,18 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() + if not sd_model: + sd_model = shared.sd_model + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) return shared.sd_model -- cgit v1.2.1 From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 12:51:46 +0700 Subject: Reload VAE without reloading sd checkpoint --- modules/sd_models.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ab85b65..883639d1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - checkpoint_key = (checkpoint_info, vae_file) + checkpoint_key = checkpoint_info if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, vae_file) - model.first_stage_model.to(devices.dtype_vae) - if shared.opts.sd_checkpoint_cache > 0: + # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU + else: vae_name = sd_vae.get_filename(vae_file) print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") @@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.load_vae(model, vae_file) + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack @@ -254,14 +253,14 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model=None, info=None, force=False): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if not sd_model: sd_model = 
shared.sd_model - if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): -- cgit v1.2.1 From f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 14:41:29 +0300 Subject: fix #3986 breaking --no-half-vae --- modules/sd_models.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 883639d1..5075fadb 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.to(memory_format=torch.channels_last) if not shared.cmd_opts.no_half: + vae = model.first_stage_model + + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + model.half() + model.first_stage_model = vae devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + model.first_stage_model.to(devices.dtype_vae) + if shared.opts.sd_checkpoint_cache > 0: # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() -- cgit v1.2.1 From 3780ad3ad837dd406da39eebd5d91009b5a58445 Mon Sep 17 00:00:00 2001 From: digburn Date: Fri, 4 Nov 2022 00:40:21 +0000 Subject: fix: loading models without vae from cache --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5075fadb..ae427a5c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,8 +204,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU else: - vae_name = sd_vae.get_filename(vae_file) - print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + vae_name = sd_vae.get_filename(vae_file) if vae_file else None + vae_message = f" with {vae_name} VAE" if vae_name else "" + print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") checkpoints_loaded.move_to_end(checkpoint_key) model.load_state_dict(checkpoints_loaded[checkpoint_key]) -- cgit v1.2.1 From 99043f33606d3057f83ea52a403e10cd29d1f7e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 11:20:42 +0300 Subject: fix one of previous merges breaking the program --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 63e07a12..34c57bfa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -167,6 +167,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): sd_vae.restore_base_vae(model) checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") -- cgit v1.2.1 From 3b51d239ac9201228c6032fc109111e347e8e6b0 Mon Sep 17 00:00:00 2001 
From: cluder <1590330+cluder@users.noreply.github.com> Date: Wed, 9 Nov 2022 04:54:21 +0100 Subject: - do not use ckpt cache, if disabled - cache model after it has been loaded from file --- modules/sd_models.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 34c57bfa..720c2a96 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + cache_enabled = shared.opts.sd_checkpoint_cache > 0 + + if cache_enabled: sd_vae.restore_base_vae(model) - checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if checkpoint_info not in checkpoints_loaded: + if cache_enabled and checkpoint_info in checkpoints_loaded: + # use checkpoint cache + vae_name = sd_vae.get_filename(vae_file) if vae_file else None + vae_message = f" with {vae_name} VAE" if vae_name else "" + print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + model.load_state_dict(checkpoints_loaded[checkpoint_info]) + else: + # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): del pl_sd model.load_state_dict(sd, strict=False) del sd + + if cache_enabled: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) @@ -199,13 +211,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.first_stage_model.to(devices.dtype_vae) - else: - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - - if shared.opts.sd_checkpoint_cache > 0: + # clean up cache if limit is reached + if cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU -- cgit v1.2.1 From eebf49592ad2c0933e58b06a098b92e48d47e4fe Mon Sep 17 00:00:00 2001 From: cluder <1590330+cluder@users.noreply.github.com> Date: Wed, 9 Nov 2022 07:17:09 +0100 Subject: restore #4035 behavior - if checkpoint cache is set to 1, keep 2 models in cache (current +1 more) --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 720c2a96..80addf03 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -213,7 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # clean up cache if limit is reached if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash -- cgit v1.2.1 From abc1e79a5da24a1ea0f4bceedcdf225f32010aa8 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov
2022 11:10:53 +0700 Subject: Fix base VAE caching was done after loading VAE, also add safeguard --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..e4dba62c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) -- cgit v1.2.1 From c7be83bf0240498d9382e2afeaa3f0677d26c7f6 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 11:11:14 +0700 Subject: Misc Misc --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e4dba62c..cd7fe37a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) -- cgit v1.2.1 From 2c5ca706a7e624d268545ba3318ba230b7b33477 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 10:55:47 +0700 Subject: Remove no longer necessary parts and add vae_file safeguard --- modules/sd_models.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..c59151e0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): cache_enabled = shared.opts.sd_checkpoint_cache > 0 - if cache_enabled: - sd_vae.restore_base_vae(model) - - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if cache_enabled and checkpoint_info in checkpoints_loaded: # use checkpoint cache - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + print(f"Loading weights [{sd_model_hash}] from cache") model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file @@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) sd_vae.load_vae(model, vae_file) -- cgit v1.2.1 From 0efffbb407a9d07eae6850374099775385ce176c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 21 Nov 2022 14:04:25 +0100 Subject: Supporting `*.safetensors` format. If a model file exists with extension `.safetensors` then we can load it more safely than with PyTorch weights. 
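Because a safetensors file stores raw tensors with no pickled Python objects, loading it cannot execute arbitrary code the way torch.load can on an untrusted .ckpt. A condensed sketch of the extension dispatch the patch below introduces; the helper name is hypothetical (the patch itself inlines this logic in load_model_weights):

```
import os

import torch
import safetensors.torch

def read_checkpoint(checkpoint_file, map_location="cpu"):
    # Pick the loader by extension: .safetensors holds raw tensors with no
    # pickle step, so it loads safely; anything else goes through torch.load.
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        return safetensors.torch.load_file(checkpoint_file, device=map_location)
    return torch.load(checkpoint_file, map_location=map_location)
```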
--- modules/sd_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..0164cc1b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -45,7 +45,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -180,7 +180,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if checkpoint_file.endswith(".safetensors"): + try: + from safetensors.torch import load_file + except ImportError as e: + raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") + pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") -- cgit v1.2.1 From 1e506657e1cb732a5f0e567ba2585fba2bbb1327 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 26 Nov 2022 13:28:44 -0500 Subject: no-half support for SD 2.0 --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index c59151e0..0e0bd79e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -244,6 +244,9 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.1 From 6074175faa751dde933aa8e15cd687ca4e4b4a23 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 14:46:40 +0300 Subject: add safetensors to requirements --- modules/sd_models.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ae36841a..77236480 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -5,6 +5,7 @@ import gc from collections import namedtuple import torch import re +import safetensors.torch from omegaconf import OmegaConf from ldm.util import instantiate_from_config @@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - if checkpoint_file.endswith(".safetensors"): - try: - from safetensors.torch import load_file - except ImportError as e: - raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") - pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location) else: pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: 
print(f"Global Step: {pl_sd['global_step']}") -- cgit v1.2.1 From dac9b6f15de5e675053d9490a20e0457dcd1a23e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 15:51:29 +0300 Subject: add safetensors support for model merging #4869 --- modules/sd_models.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 77236480..a1ea5611 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -160,6 +160,20 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) + + if print_global_state and "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + + sd = get_state_dict_from_checkpoint(pl_sd) + return sd + + def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -174,17 +188,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - _, extension = os.path.splitext(checkpoint_file) - if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location) - else: - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) - - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - - sd = get_state_dict_from_checkpoint(pl_sd) - del pl_sd + sd = read_state_dict(checkpoint_file) model.load_state_dict(sd, strict=False) del sd -- cgit v1.2.1 From 0376da180c81a11880a2587903d69d85541051e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 28 Nov 2022 08:39:59 +0300 Subject: make it possible to save nai model using safetensors --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a1ea5611..283cf1cd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -144,8 +144,8 @@ def transform_checkpoint_dict_key(k): def get_state_dict_from_checkpoint(pl_sd): - if "state_dict" in pl_sd: - pl_sd = pl_sd["state_dict"] + pl_sd = pl_sd.pop("state_dict", pl_sd) + pl_sd.pop("state_dict", None) sd = {} for k, v in pl_sd.items(): -- cgit v1.2.1 From 1ed4f0e22807f3afef925210182cbbee51f0cb2c Mon Sep 17 00:00:00 2001 From: Jay Smith Date: Thu, 8 Dec 2022 18:14:35 -0600 Subject: Depth2img model support --- modules/sd_models.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 283cf1cd..139952ba 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,6 +7,9 @@ import torch import re import safetensors.torch from omegaconf import OmegaConf +from os import mkdir +from urllib import request +import ldm.modules.midas as midas from ldm.util import instantiate_from_config @@ -36,6 +39,7 @@ def setup_model(): os.makedirs(model_path) list_models() + enable_midas_autodownload() def 
checkpoint_tiles(): @@ -227,6 +231,48 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): sd_vae.load_vae(model, vae_file) +def enable_midas_autodownload(): + """ + Gives the ldm.modules.midas.api.load_model function automatic downloading. + + When the 512-depth-ema model, and other future models like it, is loaded, + it calls midas.api.load_model to load the associated midas depth model. + This function applies a wrapper to download the model to the correct + location automatically. + """ + + midas_path = os.path.join(models_path, 'midas') + + # stable-diffusion-stability-ai hard-codes the midas model path to + # a location that differs from where other scripts using this model look. + # HACK: Overriding the path here. + for k, v in midas.api.ISL_PATHS.items(): + file_name = os.path.basename(v) + midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name) + + midas_urls = { + "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt", + "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt", + } + + midas.api.load_model_inner = midas.api.load_model + + def load_model_wrapper(model_type): + path = midas.api.ISL_PATHS[model_type] + if not os.path.exists(path): + if not os.path.exists(midas_path): + mkdir(midas_path) + + print(f"Downloading midas model weights for {model_type} to {path}") + request.urlretrieve(midas_urls[model_type], path) + print(f"{model_type} downloaded") + + return midas.api.load_model_inner(model_type) + + midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() -- cgit v1.2.1 From bd81a09eacf02dad095b98094ab936f276d0343f Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 10 Dec 2022 11:29:26 -0500 Subject: fix support for 2.0 inpainting model while maintaining support for 1.5 inpainting model --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5b37f3fe..b64f573f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -296,6 +296,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.use_ema = False sd_config.model.params.conditioning_key = "hybrid" sd_config.model.params.unet_config.params.in_channels = 9 + sd_config.model.params.finetune_keys = None # Create a "fake" config with a different name so that we know to unload it when switching models. 
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) -- cgit v1.2.1 From 59c6511494c55a578eecdf71fb4590b6bd5d04a7 Mon Sep 17 00:00:00 2001 From: Dean van Dugteren <31391056+deanpress@users.noreply.github.com> Date: Sun, 11 Dec 2022 17:08:51 +0100 Subject: fix: fallback model_checkpoint if it's empty This fixes the following error when SD attempts to start with a deleted checkpoint:
```
Traceback (most recent call last):
  File "D:\Web\stable-diffusion-webui\launch.py", line 295, in <module>
    start()
  File "D:\Web\stable-diffusion-webui\launch.py", line 290, in start
    webui.webui()
  File "D:\Web\stable-diffusion-webui\webui.py", line 132, in webui
    initialize()
  File "D:\Web\stable-diffusion-webui\webui.py", line 62, in initialize
    modules.sd_models.load_model()
  File "D:\Web\stable-diffusion-webui\modules\sd_models.py", line 283, in load_model
    checkpoint_info = checkpoint_info or select_checkpoint()
  File "D:\Web\stable-diffusion-webui\modules\sd_models.py", line 117, in select_checkpoint
    checkpoint_info = checkpoints_list.get(model_checkpoint, None)
TypeError: unhashable type: 'list'
```
--- modules/sd_models.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5b37f3fe..b6d75db7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -111,6 +111,10 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint + + if len(model_checkpoint) == 0: + model_checkpoint = shared.default_sd_model_file + checkpoint_info = checkpoints_list.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info -- cgit v1.2.1 From ec0a48826fb41c1b1baab45a9030f7eb55568fd0 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sun, 11 Dec 2022 10:19:46 -0500 Subject: unconditionally set use_ema=False if value not specified (True never worked, and all configs except v1-inpainting-inference.yaml already correctly set it to False) --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b64f573f..f36b299f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -293,7 +293,6 @@ def load_model(checkpoint_info=None): if should_hijack_inpainting(checkpoint_info): # Hardcoded config for now... sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.use_ema = False sd_config.model.params.conditioning_key = "hybrid" sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None @@ -301,6 +300,9 @@ def load_model(checkpoint_info=None): # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + do_inpainting_hijack() if shared.cmd_opts.no_half: -- cgit v1.2.1 From 5a650055de3792223a91925aba8130ebdee29e35 Mon Sep 17 00:00:00 2001 From: linuxmobile ( リナックス ) Date: Sat, 24 Dec 2022 09:25:35 -0300 Subject: Removed length check in sd_model at line 115 Commit eba60a4 is what causes this error; deleting the length check in sd_model starting at line 115 fixes it. 
https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5971#issuecomment-1364507379 --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1254e5ae..6ca06211 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -111,9 +111,6 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - - if len(model_checkpoint) == 0: - model_checkpoint = shared.default_sd_model_file checkpoint_info = checkpoints_list.get(model_checkpoint, None) if checkpoint_info is not None: -- cgit v1.2.1 From 3bf5591efe9a9f219c6088be322a87adc4f48f95 Mon Sep 17 00:00:00 2001 From: Yuval Aboulafia Date: Sat, 24 Dec 2022 21:35:29 +0200 Subject: fix F541 f-string without any placeholders --- modules/sd_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ca06211..ecdd91c5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -117,13 +117,13 @@ def select_checkpoint(): return checkpoint_info if len(checkpoints_list) == 0: - print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) + print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) if shared.cmd_opts.ckpt is not None: print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) print(f" - directory {model_path}", file=sys.stderr) if shared.cmd_opts.ckpt_dir is not None: print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) - print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) + print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) checkpoint_info = next(iter(checkpoints_list.values())) @@ -324,7 +324,7 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) - print(f"Model loaded.") + print("Model loaded.") return sd_model @@ -359,5 +359,5 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print(f"Weights loaded.") + print("Weights loaded.") return sd_model -- cgit v1.2.1 From 5ba04f9ec050a66e918571f07e8863f157f05b44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Dec 2022 13:45:58 +0100 Subject: Attempting to solve slow loads for `safetensors`. 
Fixes #5893 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..cd938656 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -168,7 +168,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + device = map_location or shared.weight_load_location + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.1 From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 31 Dec 2022 11:27:02 -0500 Subject: validate textual inversion embeddings --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..ebd4dff7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -325,6 +325,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model + return sd_model -- cgit v1.2.1 From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 00:38:09 +0300 Subject: fix the issue with training on SD2.0 --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ebd4dff7..bff8d6c9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + model.logvar = model.logvar.to(devices.device) # fix for training + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) -- cgit v1.2.1 From 8f96f9289981a66741ba770d14f3d27ce335a0fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:39:14 +0300 Subject: call script callbacks for reloaded model after loading embeddings --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index bff8d6c9..b98b05fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -324,12 +324,12 @@ def load_model(checkpoint_info=None): sd_model.eval() shared.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model - return sd_model -- cgit v1.2.1 From 
From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 12:35:07 +0300
Subject: helpful error message when trying to load 2.0 without config

failing to load model weights from settings won't break generation for
currently loaded model anymore
---
 modules/sd_models.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index b98b05fc..6846b74a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -278,6 +278,7 @@ def enable_midas_autodownload():
 
     midas.api.load_model = load_model_wrapper
 
+
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
         sd_config.model.params.unet_config.params.use_fp16 = False
 
     sd_model = instantiate_from_config(sd_config.model)
+
     load_model_weights(sd_model, checkpoint_info)
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
 def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
-
+
     if not sd_model:
         sd_model = shared.sd_model
 
+    current_checkpoint_info = sd_model.sd_checkpoint_info
+
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
@@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    load_model_weights(sd_model, checkpoint_info)
-
-    sd_hijack.model_hijack.hijack(sd_model)
-    script_callbacks.model_loaded_callback(sd_model)
-
-    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
-        sd_model.to(devices.device)
+    try:
+        load_model_weights(sd_model, checkpoint_info)
+    except Exception as e:
+        print("Failed to load checkpoint, restoring previous")
+        load_model_weights(sd_model, current_checkpoint_info)
+        raise
+    finally:
+        sd_hijack.model_hijack.hijack(sd_model)
+        script_callbacks.model_loaded_callback(sd_model)
+
+        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+            sd_model.to(devices.device)
 
     print("Weights loaded.")
+
    return sd_model
--
cgit v1.2.1
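One subtlety worth noting: because the except clause re-raises, it is the finally block that guarantees the restored weights are re-hijacked and callbacks fire before the exception propagates. A condensed sketch of the pattern, with hypothetical stand-in callables:

    # Restore-on-failure for weight swaps. The finally block runs on
    # success, on failure, and on re-raise, so the model is always left in
    # a usable, hijacked state.
    def swap_weights(model, new_info, old_info, load, rehijack, notify):
        try:
            load(model, new_info)
        except Exception:
            print("Failed to load checkpoint, restoring previous")
            load(model, old_info)
            raise
        finally:
            rehijack(model)
            notify(model)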
From 8d8a05a3bbb50fdfeab51679a919d2487bd97976 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 12:47:42 +0300
Subject: find configs for models at runtime rather than when starting

---
 modules/sd_models.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6846b74a..6dca4ddf 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
 
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
 checkpoints_list = {}
 checkpoints_loaded = collections.OrderedDict()
 
@@ -48,6 +48,14 @@ def checkpoint_tiles():
     return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
 
 
+def find_checkpoint_config(info):
+    config = os.path.splitext(info.filename)[0] + ".yaml"
+    if os.path.exists(config):
+        return config
+
+    return shared.cmd_opts.config
+
+
 def list_models():
     checkpoints_list.clear()
     model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
     if os.path.exists(cmd_ckpt):
         h = model_hash(cmd_ckpt)
         title, short_model_name = modeltitle(cmd_ckpt, h)
-        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
         shared.opts.data['sd_model_checkpoint'] = title
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
         h = model_hash(filename)
         title, short_model_name = modeltitle(filename, h)
 
-        basename, _ = os.path.splitext(filename)
-        config = basename + ".yaml"
-        if not os.path.exists(config):
-            config = shared.cmd_opts.config
-
-        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
 
 
 def get_closet_checkpoint_match(searchString):
@@ -282,9 +285,10 @@ def enable_midas_autodownload():
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
+    checkpoint_config = find_checkpoint_config(checkpoint_info)
 
-    if checkpoint_info.config != shared.cmd_opts.config:
-        print(f"Loading config from: {checkpoint_info.config}")
+    if checkpoint_config != shared.cmd_opts.config:
+        print(f"Loading config from: {checkpoint_config}")
 
     if shared.sd_model:
         sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
     gc.collect()
     devices.torch_gc()
 
-    sd_config = OmegaConf.load(checkpoint_info.config)
+    sd_config = OmegaConf.load(checkpoint_config)
 
     if should_hijack_inpainting(checkpoint_info):
         # Hardcoded config for now...
@@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
         sd_config.model.params.unet_config.params.in_channels = 9
         sd_config.model.params.finetune_keys = None
 
         # Create a "fake" config with a different name so that we know to unload it when switching models.
-        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+        checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
 
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False
@@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
         sd_model = shared.sd_model
 
     current_checkpoint_info = sd_model.sd_checkpoint_info
+    checkpoint_config = find_checkpoint_config(current_checkpoint_info)
 
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info)
--
cgit v1.2.1
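The new lookup is a sidecar-file convention: a checkpoint named model.safetensors uses a model.yaml sitting next to it when one exists, and the global --config default otherwise. A standalone sketch of the same idea:

    import os

    def sidecar_config(checkpoint_path, default_config):
        # "model.safetensors" -> "model.yaml"; fall back to the global default
        config = os.path.splitext(checkpoint_path)[0] + ".yaml"
        return config if os.path.exists(config) else default_config

Deferring the lookup to load time means a .yaml dropped next to an already-listed checkpoint is picked up without restarting the program, which is the point of the commit subject.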
From 0cd6399b8b1699b8b7acad6f0ad2988111fe618e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 14:29:13 +0300
Subject: fix broken inpainting model

---
 modules/sd_models.py | 3 ---
 1 file changed, 3 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6dca4ddf..a568823d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -305,9 +305,6 @@ def load_model(checkpoint_info=None):
         sd_config.model.params.unet_config.params.in_channels = 9
         sd_config.model.params.finetune_keys = None
 
-        # Create a "fake" config with a different name so that we know to unload it when switching models.
-        checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
-
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False
--
cgit v1.2.1


From 642142556d8ecdea9beb86d7618b628b1803ab98 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 4 Jan 2023 15:09:53 +0300
Subject: use commandline-supplied cuda device name instead of cuda:0 for safetensors PR that doesn't fix anything

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index ee918f24..76a89e88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     if extension.lower() == ".safetensors":
         device = map_location or shared.weight_load_location
         if device is None:
-            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
         pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
--
cgit v1.2.1
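devices.get_cuda_device_string() is a webui helper whose body does not appear in this log; a plausible reading of what it returns, assuming a --device-id style command-line option, is:

    import torch

    def get_cuda_device_string(device_id=None):
        # With an explicit id (e.g. "--device-id 1") target that GPU;
        # otherwise use the bare "cuda" alias, which respects
        # torch.cuda.set_device() and CUDA_VISIBLE_DEVICES instead of
        # hardcoding "cuda:0".
        if device_id is not None:
            return f"cuda:{device_id}"
        return "cuda"

    device = get_cuda_device_string() if torch.cuda.is_available() else "cpu"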
From 552d7b90bf483c160cd20740f7acd7fccbc02e6f Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Mon, 9 Jan 2023 18:34:26 -0500
Subject: allow model load if previous model failed

---
 modules/sd_models.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 76a89e88..0a6d55ca 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -49,6 +49,9 @@ def checkpoint_tiles():
 
 def find_checkpoint_config(info):
+    if info is None:
+        return shared.cmd_opts.config
+
     config = os.path.splitext(info.filename)[0] + ".yaml"
     if os.path.exists(config):
         return config
@@ -345,14 +348,16 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
 
+    if sd_model is None:  # previous model load failed
+        current_checkpoint_info = None
+    else:
+        current_checkpoint_info = sd_model.sd_checkpoint_info
+        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+            return
 
-    current_checkpoint_info = sd_model.sd_checkpoint_info
     checkpoint_config = find_checkpoint_config(current_checkpoint_info)
 
-    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-        return
-
-    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+    if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info)
--
cgit v1.2.1


From 0c3feb202c5714abd50d879c1db2cd9a71ce93e3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 14:08:29 +0300
Subject: disable torch weight initialization and CLIP downloading/reading
 checkpoint to speedup creating sd model from config

---
 modules/sd_models.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 0a6d55ca..ee241032 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -13,7 +13,7 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
@@ -319,7 +319,8 @@ def load_model(checkpoint_info=None):
     if shared.cmd_opts.no_half:
         sd_config.model.params.unet_config.params.use_fp16 = False
 
-    sd_model = instantiate_from_config(sd_config.model)
+    with sd_disable_initialization.DisableInitialization():
+        sd_model = instantiate_from_config(sd_config.model)
 
     load_model_weights(sd_model, checkpoint_info)
--
cgit v1.2.1
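DisableInitialization (defined in sd_disable_initialization, which is not shown in this log) speeds up instantiate_from_config by skipping work that is pointless when a checkpoint will overwrite the weights anyway. A minimal sketch of the weight-init half of that idea, with hypothetical naming; the real class also stubs out CLIP downloads and reads:

    import torch

    class SkipWeightInit:
        """Temporarily turn an expensive torch weight initializer into a no-op."""

        def __enter__(self):
            self.saved = torch.nn.init.kaiming_uniform_
            # Leave the tensor as allocated; real values arrive with the checkpoint.
            torch.nn.init.kaiming_uniform_ = lambda tensor, *args, **kwargs: tensor
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            torch.nn.init.kaiming_uniform_ = self.saved

Usage would mirror the diff: `with SkipWeightInit(): model = build_model(config)`.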
From ce3f639ec8758ce2bc90483336361d2dc25acd3a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 16:51:04 +0300
Subject: add more stuff to ignore when creating model from config

prevent .vae.safetensors files from being listed as stable diffusion models
---
 modules/sd_models.py | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index ee241032..1bb9088b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -2,6 +2,7 @@ import collections
 import os.path
 import sys
 import gc
+import time
 from collections import namedtuple
 import torch
 import re
@@ -61,7 +62,7 @@ def find_checkpoint_config(info):
 
 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
 
     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)
@@ -288,6 +289,17 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper
 
 
+class Timer:
+    def __init__(self):
+        self.start = time.time()
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+
 def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -319,11 +331,17 @@ def load_model(checkpoint_info=None):
     if shared.cmd_opts.no_half:
         sd_config.model.params.unet_config.params.use_fp16 = False
 
+    timer = Timer()
+
     with sd_disable_initialization.DisableInitialization():
         sd_model = instantiate_from_config(sd_config.model)
 
+    elapsed_create = timer.elapsed()
+
     load_model_weights(sd_model, checkpoint_info)
 
+    elapsed_load_weights = timer.elapsed()
+
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
     else:
@@ -338,7 +356,9 @@ def load_model(checkpoint_info=None):
 
     script_callbacks.model_loaded_callback(sd_model)
 
-    print("Model loaded.")
+    elapsed_the_rest = timer.elapsed()
+
+    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
 
     return sd_model
 
@@ -349,7 +369,7 @@ def reload_model_weights(sd_model=None, info=None):
     if not sd_model:
         sd_model = shared.sd_model
 
-    if sd_model is None: # previous model load failed
+    if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
@@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None):
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
+    timer = Timer()
+
     try:
         load_model_weights(sd_model, checkpoint_info)
     except Exception as e:
@@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None):
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
 
-    print("Weights loaded.")
+    elapsed = timer.elapsed()
+
+    print(f"Weights loaded in {elapsed:.1f}s.")
 
     return sd_model
--
cgit v1.2.1


From 0f8603a55988d22616b17140e6c4a7e9d0736af5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 10 Jan 2023 17:46:59 +0300
Subject: add support for transformers==4.25.1

add fallback for when quick model creation fails
---
 modules/sd_models.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1bb9088b..b5bc12f0 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
 from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
@@ -333,7 +333,11 @@ def load_model(checkpoint_info=None):
 
     timer = Timer()
 
-    with sd_disable_initialization.DisableInitialization():
+    try:
+        with sd_disable_initialization.DisableInitialization():
+            sd_model = instantiate_from_config(sd_config.model)
+    except Exception as e:
+        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)
 
     elapsed_create = timer.elapsed()
--
cgit v1.2.1
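The Timer helper introduced in the ce3f639 commit above is a stopwatch that resets on every read, so each elapsed() call measures only the stage since the previous call. A small usage sketch (the class body is taken from the diff; the staged sleeps are stand-ins):

    import time

    class Timer:
        def __init__(self):
            self.start = time.time()

        def elapsed(self):
            end = time.time()
            res = end - self.start
            self.start = end  # reset, so each call times one stage
            return res

    timer = Timer()
    time.sleep(0.2)                         # stand-in for creating the model
    elapsed_create = timer.elapsed()        # ~0.2s
    time.sleep(0.3)                         # stand-in for loading weights
    elapsed_load_weights = timer.elapsed()  # ~0.3s, not ~0.5s
    print(f"Model loaded in {elapsed_create + elapsed_load_weights:.1f}s "
          f"({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")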
From 4fdacd31e48c6a7a35c1c25c559932585e8addde Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 10:24:56 +0300
Subject: possible fix for fallback for fast model creation from config

---
 modules/sd_models.py | 3 +++
 1 file changed, 3 insertions(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index b5bc12f0..a0a8a909 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -337,6 +337,9 @@ def load_model(checkpoint_info=None):
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
     except Exception as e:
+        pass
+
+    if sd_model is None:
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         sd_model = instantiate_from_config(sd_config.model)
 
--
cgit v1.2.1


From 1a23dc32ac5e16fac10115cafd0b841abd06e59f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 10:34:36 +0300
Subject: possible fix for fallback for fast model creation from config, attempt 2

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index a0a8a909..084ba7fa 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -333,6 +333,7 @@ def load_model(checkpoint_info=None):
 
     timer = Timer()
 
+    sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
--
cgit v1.2.1


From 4bd490727e156ff53107d53416d6b89be86f2a62 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 11 Jan 2023 18:54:04 +0300
Subject: fix for an error caused by skipping initialization, for realsies
 this time: TypeError: expected str, bytes or os.PathLike object, not
 NoneType

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 084ba7fa..c466f273 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -334,6 +334,7 @@ def load_model(checkpoint_info=None):
     timer = Timer()
 
     sd_model = None
+
     try:
         with sd_disable_initialization.DisableInitialization():
             sd_model = instantiate_from_config(sd_config.model)
--
cgit v1.2.1
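Taken together, the last three commits leave the fast-creation fallback in roughly this final shape (paraphrased from the diffs above, not quoted from the repository):

    # Initialize sd_model before the try block so the name exists even when
    # DisableInitialization itself raises, then retry via the slow path.
    sd_model = None

    try:
        with sd_disable_initialization.DisableInitialization():
            sd_model = instantiate_from_config(sd_config.model)
    except Exception:
        pass

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
        sd_model = instantiate_from_config(sd_config.model)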