aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_models.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 80addf03..0164cc1b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -45,7 +45,7 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+ model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
@@ -180,7 +180,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
# load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
- pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+ if checkpoint_file.endswith(".safetensors"):
+ try:
+ from safetensors.torch import load_file
+ except ImportError as e:
+ raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
+ pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
+ else:
+ pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")