From d132481058f8a827cd407f2121f128a2bb862f7a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 2 Apr 2023 17:41:55 -0500 Subject: Embed model merge metadata in .safetensors file --- modules/sd_models.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ea874df..4f7613a1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -52,6 +52,15 @@ class CheckpointInfo: self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) + self.metadata = {} + + _, ext = os.path.splitext(self.filename) + if ext.lower() == ".safetensors": + try: + self.metadata = read_metadata_from_safetensors(filename) + except Exception as e: + errors.display(e, f"reading checkpoint metadata: {filename}") + def register(self): checkpoints_list[self.title] = self for id in self.ids: @@ -544,4 +553,4 @@ def unload_model_weights(sd_model=None, info=None): print(f"Unloaded weights {timer.summary()}.") - return sd_model \ No newline at end of file + return sd_model -- cgit v1.2.1 From b1717c0a4804f8ed3bb8cc2f3aea5d095778b447 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 2 May 2023 09:08:00 +0300 Subject: do not load wait for shared.sd_model to load at startup --- modules/sd_models.py | 54 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 14 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 4f7613a1..59adc7cc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,8 @@ import collections import os.path import sys import gc +import threading + import torch import re import safetensors.torch @@ -404,13 +406,39 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' -def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): + +class SdModelData: + def __init__(self): + self.sd_model = None + self.lock = threading.Lock() + + def get_sd_model(self): + if self.sd_model is None: + with self.lock: + try: + load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load", file=sys.stderr) + self.sd_model = None + + return self.sd_model + + def set_sd_model(self, v): + self.sd_model = v + + +model_data = SdModelData() + + +def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - if shared.sd_model: - sd_hijack.model_hijack.undo_hijack(shared.sd_model) - shared.sd_model = None + if model_data.sd_model: + sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + model_data.sd_model = None gc.collect() devices.torch_gc() @@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ timer.record("hijack") sd_model.eval() - shared.sd_model = sd_model + model_data.sd_model = sd_model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_info = info or select_checkpoint() if not sd_model: - sd_model = shared.sd_model + sd_model = model_data.sd_model if sd_model is None: # previous model load failed current_checkpoint_info = None @@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info, already_loaded_state_dict=state_dict) - return shared.sd_model + return model_data.sd_model try: load_model_weights(sd_model, checkpoint_info, state_dict, timer) @@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None): return sd_model + def unload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack timer = Timer() - if shared.sd_model: - - # shared.sd_model.cond_stage_model.to(devices.cpu) - # shared.sd_model.first_stage_model.to(devices.cpu) - shared.sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(shared.sd_model) - shared.sd_model = None + if model_data.sd_model: + model_data.sd_model.to(devices.cpu) + sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + model_data.sd_model = None sd_model = None gc.collect() devices.torch_gc() -- cgit v1.2.1