aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ff5d17d..060e0007 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
checkpoints_list = {}
-checkpoint_alisases = {}
+checkpoint_aliases = {}
+checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
@@ -66,7 +67,7 @@ class CheckpointInfo:
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
- checkpoint_alisases[id] = self
+ checkpoint_aliases[id] = self
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
@@ -112,7 +113,7 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
- checkpoint_alisases.clear()
+ checkpoint_aliases.clear()
cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
@@ -136,7 +137,7 @@ def list_models():
def get_closet_checkpoint_match(search_string):
- checkpoint_info = checkpoint_alisases.get(search_string, None)
+ checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -166,7 +167,7 @@ def select_checkpoint():
"""Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -585,7 +591,6 @@ def unload_model_weights(sd_model=None, info=None):
sd_model = None
gc.collect()
devices.torch_gc()
- torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.")