aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index dc81b0dc..9decc911 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -10,7 +10,7 @@ from ldm.util import instantiate_from_config
from modules import shared
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
try:
@@ -45,7 +45,8 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
+ model_name = title.rsplit(".",1)[0] # remove extension if present
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
@@ -53,7 +54,8 @@ def list_models():
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
title = modeltitle(filename, h)
- checkpoints_list[title] = CheckpointInfo(filename, title, h)
+ model_name = title.rsplit(".",1)[0] # remove extension if present
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
def model_hash(filename):