From c1928cdd6194928af0f53f70c51d59479b7025e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 18:58:08 +0300 Subject: bring back short hashes to sd checkpoint selection --- modules/sd_models.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef..12083848 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] - self.title = name + self.name = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) self.shorthash = self.sha256[0:10] if self.sha256 else None - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: self.ids += [self.shorthash, self.sha256] self.register() + self.title = f'{self.name} [{self.shorthash}]' + return self.shorthash @@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo): + title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + if checkpoint_info.title != title: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title cache_enabled = shared.opts.sd_checkpoint_cache > 0 -- cgit v1.2.1