aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index acb1e817..cb67e425 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -52,6 +52,7 @@ class CheckpointInfo:
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+ self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
@@ -81,6 +82,7 @@ class CheckpointInfo:
checkpoints_list.pop(self.title)
self.title = f'{self.name} [{self.shorthash}]'
+ self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
self.register()
return self.shorthash
@@ -101,14 +103,8 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+ return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -131,11 +127,14 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
- for filename in sorted(model_list, key=str.lower):
+ for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
@@ -145,6 +144,11 @@ def get_closet_checkpoint_match(search_string):
if found:
return found[0]
+ search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+ found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+
return None