aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py46
1 files changed, 28 insertions, 18 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index fb31a793..1d93d893 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -33,6 +33,8 @@ class CheckpointInfo:
self.filename = filename
abspath = os.path.abspath(filename)
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
@@ -43,6 +45,19 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
+ def read_metadata():
+ metadata = read_metadata_from_safetensors(filename)
+ self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
+
+ return metadata
+
+ self.metadata = {}
+ if self.is_safetensors:
+ try:
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
+ except Exception as e:
+ errors.display(e, f"reading metadata for {filename}")
+
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
@@ -53,16 +68,7 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
-
- self.metadata = {}
-
- _, ext = os.path.splitext(self.filename)
- if ext.lower() == ".safetensors":
- try:
- self.metadata = read_metadata_from_safetensors(filename)
- except Exception as e:
- errors.display(e, f"reading checkpoint metadata: {filename}")
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -79,7 +85,7 @@ class CheckpointInfo:
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
- checkpoints_list.pop(self.title)
+ checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
self.register()
@@ -460,7 +466,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
-
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -495,19 +500,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
- sd_model = instantiate_from_config(sd_config.model)
- except Exception:
- pass
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+
+ except Exception as e:
+ errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
- sd_model = instantiate_from_config(sd_config.model)
+
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)