From b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9 Mon Sep 17 00:00:00 2001
From: Josh Watzman
Date: Thu, 27 Oct 2022 21:59:16 +0100
Subject: Reduce peak memory usage when changing models

A few tweaks to reduce peak memory usage, the biggest being that if we
aren't using the checkpoint cache, we shouldn't duplicate the model
state dict just to immediately throw it away.

On my machine with 16GB of RAM, this change means I can typically change
models, whereas before it would typically OOM.
---
 modules/sd_models.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index e697bb72..203e99a8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
             print(f"Global Step: {pl_sd['global_step']}")
 
         sd = get_state_dict_from_checkpoint(pl_sd)
-        missing, extra = model.load_state_dict(sd, strict=False)
+        del pl_sd
+        model.load_state_dict(sd, strict=False)
+        del sd
 
         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
@@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info):
 
         model.first_stage_model.to(devices.dtype_vae)
 
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
-            checkpoints_loaded.popitem(last=False)  # LRU
+        if shared.opts.sd_checkpoint_cache > 0:
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+                checkpoints_loaded.popitem(last=False)  # LRU
     else:
         print(f"Loading weights [{sd_model_hash}] from cache")
         checkpoints_loaded.move_to_end(checkpoint_info)
--
cgit v1.2.1
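A note on the technique in this patch: the peak comes from holding several
full copies of the weights at once. torch.load returns the pickled
checkpoint (pl_sd), the extracted state dict (sd) references the tensors
inside it, and nn.Module.load_state_dict copies into the model's own
tensors, so del-ing each name as soon as it is consumed lets Python free
the earlier copy before the next allocation. The following is a minimal
standalone sketch of the same pattern, not this repository's API:
load_weights, cache, and cache_size are illustrative stand-ins.

    from collections import OrderedDict

    import torch

    def load_weights(model, checkpoint_file, cache=None, cache_size=0):
        # The pickled checkpoint holds one full copy of the weights.
        pl_sd = torch.load(checkpoint_file, map_location="cpu")
        sd = pl_sd.get("state_dict", pl_sd)
        del pl_sd  # drop the wrapper dict; only `sd` keeps the tensors alive

        # load_state_dict copies into the model's own tensors, so the loose
        # copy in `sd` can be released immediately afterwards.
        model.load_state_dict(sd, strict=False)
        del sd

        # Duplicate the state dict only when a cache is actually in use,
        # rather than copying unconditionally and evicting right away.
        # `cache` must be a collections.OrderedDict here: plain
        # dict.popitem() accepts no `last` argument.
        if cache is not None and cache_size > 0:
            cache[checkpoint_file] = model.state_dict().copy()
            while len(cache) > cache_size:
                cache.popitem(last=False)  # evict the oldest entry (LRU)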
From 5d5dc64064d8ca399a76fe44dbb62bdef6c4b7c4 Mon Sep 17 00:00:00 2001
From: Antonio
Date: Fri, 28 Oct 2022 05:49:39 +0200
Subject: Natural sorting for dropdown checkpoint list

Example:

    Before                          After
    11.ckpt                         11.ckpt
    ab.ckpt                         ab.ckpt
    ade_pablo_step_1000.ckpt        ade_pablo_step_500.ckpt
    ade_pablo_step_500.ckpt         ade_pablo_step_1000.ckpt
    ade_step_1000.ckpt              ade_step_500.ckpt
    ade_step_1500.ckpt              ade_step_1000.ckpt
    ade_step_2000.ckpt              ade_step_1500.ckpt
    ade_step_2500.ckpt              ade_step_2000.ckpt
    ade_step_3000.ckpt              ade_step_2500.ckpt
    ade_step_500.ckpt               ade_step_3000.ckpt
    atp_step_5500.ckpt              atp_step_5500.ckpt
    model1.ckpt                     model1.ckpt
    model10.ckpt                    model10.ckpt
    model1000.ckpt                  model33.ckpt
    model33.ckpt                    model50.ckpt
    model400.ckpt                   model400.ckpt
    model50.ckpt                    model1000.ckpt
    moo44.ckpt                      moo44.ckpt
    v1-4-pruned-emaonly.ckpt        v1-4-pruned-emaonly.ckpt
    v1-5-pruned-emaonly.ckpt        v1-5-pruned-emaonly.ckpt
    v1-5-pruned.ckpt                v1-5-pruned.ckpt
    v1-5-vae.ckpt                   v1-5-vae.ckpt
---
 modules/sd_models.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index e697bb72..64d5ee0d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -3,6 +3,7 @@ import os.path
 import sys
 from collections import namedtuple
 import torch
+import re
 from omegaconf import OmegaConf
 from ldm.util import instantiate_from_config
 
@@ -35,8 +36,10 @@ def setup_model():
     list_models()
 
 
-def checkpoint_tiles():
-    return sorted([x.title for x in checkpoints_list.values()])
+def checkpoint_tiles():
+    convert = lambda name: int(name) if name.isdigit() else name.lower()
+    alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+    return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
 
 
 def list_models():
--
cgit v1.2.1
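The sort key in this patch is the standard natural-sort idiom: split each
title into alternating text and digit runs with a capturing re.split, and
compare the digit runs as integers. Because the split always yields text
at even positions and numbers at odd ones (with a leading '' when a name
starts with a digit), Python's list comparison never compares an int
against a str. A small self-contained illustration of the same key,
using made-up filenames:

    import re

    def natural_key(s):
        # "model10.ckpt" -> ["model", 10, ".ckpt"]; digit runs become ints,
        # so model33 sorts before model1000 instead of after it.
        return [int(part) if part.isdigit() else part.lower()
                for part in re.split(r'([0-9]+)', s)]

    names = ["model1000.ckpt", "model33.ckpt", "model1.ckpt", "model10.ckpt"]
    print(sorted(names))
    # ['model1.ckpt', 'model10.ckpt', 'model1000.ckpt', 'model33.ckpt']
    print(sorted(names, key=natural_key))
    # ['model1.ckpt', 'model10.ckpt', 'model33.ckpt', 'model1000.ckpt']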