From 4b43480fe8b65a3bd24dc9bc03a7e910c9b0314f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 07:08:11 +0300 Subject: show metadata for SD checkpoints in the extra networks UI --- modules/sd_models.py | 26 ++++++++++++++++---------- modules/ui_extra_networks_checkpoints.py | 1 + 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..1af7fd78 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -33,6 +33,8 @@ class CheckpointInfo: self.filename = filename abspath = os.path.abspath(filename) + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): name = abspath.replace(shared.cmd_opts.ckpt_dir, '') elif abspath.startswith(model_path): @@ -43,6 +45,19 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] + def read_metadata(): + metadata = read_metadata_from_safetensors(filename) + self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None) + + return metadata + + self.metadata = {} + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading metadata for {filename}") + self.name = name self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] @@ -55,15 +70,6 @@ class CheckpointInfo: self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) - self.metadata = {} - - _, ext = os.path.splitext(self.filename) - if ext.lower() == ".safetensors": - try: - self.metadata = read_metadata_from_safetensors(filename) - except Exception as e: - errors.display(e, f"reading checkpoint metadata: {filename}") - def register(self): checkpoints_list[self.title] = self for id in self.ids: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 76780cfd..2bb0a222 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -23,6 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": checkpoint.metadata, "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, } -- cgit v1.2.1