aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9decc911..7a5edced 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -23,6 +23,10 @@ except Exception:
pass
+def checkpoint_tiles():
+ return sorted([x.title for x in checkpoints_list.values()])
+
+
def list_models():
checkpoints_list.clear()
@@ -39,13 +43,14 @@ def list_models():
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
- return f'{name} [{h}]'
+ shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+
+ return f'{name} [{h}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
- title = modeltitle(cmd_ckpt, h)
- model_name = title.rsplit(".",1)[0] # remove extension if present
+ title, model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
@@ -53,8 +58,7 @@ def list_models():
if os.path.exists(model_dir):
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
- title = modeltitle(filename, h)
- model_name = title.rsplit(".",1)[0] # remove extension if present
+ title, model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)