aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-31 07:38:34 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-31 07:38:34 +0300
commit5ef669de080814067961f28357256e8fe27544f4 (patch)
tree655f4582e692f0fc3667b3b668ad365ac3ab92ae /modules/sd_models.py
parentc9c8485bc1e8720aba70f029d25cba1c4abf2b5c (diff)
parente7965a5eb804a51e949df07c66c0b7c61ab7fa7b (diff)
Merge branch 'release_candidate'
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py295
1 files changed, 235 insertions, 60 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index fb31a793..930d0bee 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,8 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
-from modules.sd_hijack_inpainting import do_inpainting_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer
import tomesd
@@ -28,11 +27,31 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
+def replace_key(d, key, new_key, value):
+ keys = list(d.keys())
+
+ d[new_key] = value
+
+ if key not in keys:
+ return d
+
+ index = keys.index(key)
+ keys[index] = new_key
+
+ new_d = {k: d[k] for k in keys}
+
+ d.clear()
+ d.update(new_d)
+ return d
+
+
class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
abspath = os.path.abspath(filename)
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
@@ -43,6 +62,19 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
+ def read_metadata():
+ metadata = read_metadata_from_safetensors(filename)
+ self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
+
+ return metadata
+
+ self.metadata = {}
+ if self.is_safetensors:
+ try:
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
+ except Exception as e:
+ errors.display(e, f"reading metadata for {filename}")
+
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
@@ -52,17 +84,11 @@ class CheckpointInfo:
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+ self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
-
- self.metadata = {}
-
- _, ext = os.path.splitext(self.filename)
- if ext.lower() == ".safetensors":
- try:
- self.metadata = read_metadata_from_safetensors(filename)
- except Exception as e:
- errors.display(e, f"reading checkpoint metadata: {filename}")
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
+ if self.shorthash:
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
def register(self):
checkpoints_list[self.title] = self
@@ -74,13 +100,20 @@ class CheckpointInfo:
if self.sha256 is None:
return
- self.shorthash = self.sha256[0:10]
+ shorthash = self.sha256[0:10]
+ if self.shorthash == self.sha256[0:10]:
+ return self.shorthash
+
+ self.shorthash = shorthash
if self.shorthash not in self.ids:
- self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
- checkpoints_list.pop(self.title)
+ old_title = self.title
self.title = f'{self.name} [{self.shorthash}]'
+ self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
+
+ replace_key(checkpoints_list, old_title, self.title, self)
self.register()
return self.shorthash
@@ -101,14 +134,8 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+ return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -131,12 +158,18 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
- for filename in sorted(model_list, key=str.lower):
+ for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
+ if not search_string:
+ return None
+
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -145,6 +178,11 @@ def get_closet_checkpoint_match(search_string):
if found:
return found[0]
+ search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+ found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+
return None
@@ -280,11 +318,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return res
+class SkipWritingToConfig:
+ """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
+
+ skip = False
+ previous = None
+
+ def __enter__(self):
+ self.previous = SkipWritingToConfig.skip
+ SkipWritingToConfig.skip = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ SkipWritingToConfig.skip = self.previous
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -297,18 +351,23 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_models_xl.extend_sdxl(model)
model.load_state_dict(state_dict, strict=False)
- del state_dict
timer.record("apply weights to model")
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ checkpoints_loaded[checkpoint_info] = state_dict
+
+ del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
- if not shared.cmd_opts.no_half:
+ if shared.cmd_opts.no_half:
+ model.float()
+ devices.dtype_unet = torch.float32
+ timer.record("apply float()")
+ else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -324,9 +383,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
+ devices.dtype_unet = torch.float16
timer.record("apply half()")
- devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -346,7 +405,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")
@@ -423,6 +482,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
def __init__(self):
self.sd_model = None
+ self.loaded_sd_models = []
self.was_loaded_at_least_once = False
self.lock = threading.Lock()
@@ -437,6 +497,7 @@ class SdModelData:
try:
load_model()
+
except Exception as e:
errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
@@ -445,14 +506,30 @@ class SdModelData:
return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
+
+ try:
+ self.loaded_sd_models.remove(v)
+ except ValueError:
+ pass
+
+ if v is not None:
+ self.loaded_sd_models.insert(0, v)
model_data = SdModelData()
def get_empty_cond(sd_model):
+
+ p = processing.StableDiffusionProcessingTxt2Img()
+ extra_networks.activate(p, {})
+
if hasattr(sd_model, 'conditioner'):
d = sd_model.get_learned_conditioning([""])
return d['crossattn']
@@ -460,20 +537,46 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
+def send_model_to_cpu(m):
+ if m.lowvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ m.to(devices.cpu)
+
+ devices.torch_gc()
+
+
+def model_target_device(m):
+ if lowvram.is_needed(m):
+ return devices.cpu
+ else:
+ return devices.device
+
+
+def send_model_to_device(m):
+ lowvram.apply(m)
+
+ if not m.lowvram:
+ m.to(shared.device)
+
+
+def send_model_to_trash(m):
+ m.to(device="meta")
+ devices.torch_gc()
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
- from modules import lowvram, sd_hijack
+ from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ timer = Timer()
+
if model_data.sd_model:
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ send_model_to_trash(model_data.sd_model)
model_data.sd_model = None
- gc.collect()
devices.torch_gc()
- do_inpainting_hijack()
-
- timer = Timer()
+ timer.record("unload existing model")
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
@@ -495,25 +598,35 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
- sd_model = instantiate_from_config(sd_config.model)
- except Exception:
- pass
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+
+ except Exception as e:
+ errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
- sd_model = instantiate_from_config(sd_config.model)
+
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
-
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
else:
- sd_model.to(shared.device)
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ timer.record("load weights from state dict")
+ send_model_to_device(sd_model)
timer.record("move model to device")
sd_hijack.model_hijack.hijack(sd_model)
@@ -521,7 +634,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("hijack")
sd_model.eval()
- model_data.sd_model = sd_model
+ model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -542,10 +655,70 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
return sd_model
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+ """
+ Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+ If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
+ If not, returns the model that can be used to load weights from checkpoint_info's file.
+ If no such model exists, returns None.
+ Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+ """
+
+ already_loaded = None
+ for i in reversed(range(len(model_data.loaded_sd_models))):
+ loaded_model = model_data.loaded_sd_models[i]
+ if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ already_loaded = loaded_model
+ continue
+
+ if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+ print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+ model_data.loaded_sd_models.pop()
+ send_model_to_trash(loaded_model)
+ timer.record("send model to trash")
+
+ if shared.opts.sd_checkpoints_keep_in_cpu:
+ send_model_to_cpu(sd_model)
+ timer.record("send model to cpu")
+
+ if already_loaded is not None:
+ send_model_to_device(already_loaded)
+ timer.record("send model to device")
+
+ model_data.set_sd_model(already_loaded, already_loaded=True)
+
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
+ shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
+
+ print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
+ return model_data.sd_model
+ elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+ print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+ model_data.sd_model = None
+ load_model(checkpoint_info)
+ return model_data.sd_model
+ elif len(model_data.loaded_sd_models) > 0:
+ sd_model = model_data.loaded_sd_models.pop()
+ model_data.sd_model = sd_model
+
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
+ print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+ return sd_model
+ else:
+ return None
+
+
def reload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ timer = Timer()
+
if not sd_model:
sd_model = model_data.sd_model
@@ -554,19 +727,17 @@ def reload_model_weights(sd_model=None, info=None):
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
- return
-
- sd_unet.apply_unet("None")
+ return sd_model
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- else:
- sd_model.to(devices.cpu)
+ sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+ if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ return sd_model
+ if sd_model is not None:
+ sd_unet.apply_unet("None")
+ send_model_to_cpu(sd_model)
sd_hijack.model_hijack.undo_hijack(sd_model)
- timer = Timer()
-
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -574,7 +745,9 @@ def reload_model_weights(sd_model=None, info=None):
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
- del sd_model
+ if sd_model is not None:
+ send_model_to_trash(sd_model)
+
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
@@ -591,17 +764,19 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
print(f"Weights loaded in {timer.summary()}.")
+ model_data.set_sd_model(sd_model)
+ sd_unet.apply_unet()
+
return sd_model
def unload_model_weights(sd_model=None, info=None):
- from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model: