aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py89
1 files changed, 65 insertions, 24 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 36f643e1..f65f4e36 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,10 +14,10 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -87,8 +87,7 @@ class CheckpointInfo:
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -96,10 +95,8 @@ except Exception:
def setup_model():
- if not os.path.exists(model_path):
- os.makedirs(model_path)
+ os.makedirs(model_path, exist_ok=True)
- list_models()
enable_midas_autodownload()
@@ -166,21 +163,22 @@ def model_hash(filename):
def select_checkpoint():
+ """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
-
+
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
+ error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
+ raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
@@ -239,7 +237,7 @@ def read_metadata_from_safetensors(filename):
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -249,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -315,8 +318,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
timer.record("apply half()")
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
devices.dtype_unet = model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
@@ -374,7 +375,7 @@ def enable_midas_autodownload():
if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
-
+
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
@@ -410,15 +411,22 @@ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_w
class SdModelData:
def __init__(self):
self.sd_model = None
+ self.was_loaded_at_least_once = False
self.lock = threading.Lock()
def get_sd_model(self):
+ if self.was_loaded_at_least_once:
+ return self.sd_model
+
if self.sd_model is None:
with self.lock:
+ if self.sd_model is not None or self.was_loaded_at_least_once:
+ return self.sd_model
+
try:
load_model()
except Exception as e:
- errors.display(e, "loading stable diffusion model")
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
@@ -467,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -493,6 +501,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model.eval()
model_data.sd_model = sd_model
+ model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -502,6 +511,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
+ with devices.autocast(), torch.no_grad():
+ sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+
+ timer.record("calculate empty prompt")
+
print(f"Model loaded in {timer.summary()}.")
return sd_model
@@ -521,6 +535,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@@ -538,13 +554,12 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -565,7 +580,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
@@ -580,3 +595,29 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+ """
+ Applies speed and memory optimizations from tomesd.
+ """
+
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+ if current_token_merging_ratio == token_merging_ratio:
+ return
+
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+
+ if token_merging_ratio > 0:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+
+ sd_model.applied_token_merged_ratio = token_merging_ratio