aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-05 07:52:50 +0300
committerGitHub <noreply@github.com>2023-08-05 07:52:50 +0300
commitc613416af375092f55b9bc8649c949e95d250c44 (patch)
tree69e3bf4d53f4113f192116c252a2e410bb5b1f90 /modules
parent0ae2767ae6bb775de448b0d8cda1806edb2aef67 (diff)
parent22ecb78b51f7e6f0234cbc0efbde4ee9a2dc466f (diff)
Merge pull request #12227 from AUTOMATIC1111/multiple_loaded_models
option to keep multiple models in memory
Diffstat (limited to 'modules')
-rw-r--r--modules/lowvram.py3
-rw-r--r--modules/sd_hijack.py10
-rw-r--r--modules/sd_hijack_inpainting.py2
-rw-r--r--modules/sd_models.py135
-rw-r--r--modules/sd_models_xl.py8
-rw-r--r--modules/shared.py12
6 files changed, 135 insertions, 35 deletions
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 3f830664..96f52b7b 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,9 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram):
+ if getattr(sd_model, 'lowvram', False):
+ return
+
sd_model.lowvram = True
parents = {}
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 609fd56c..9ad98199 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
+
+sd_hijack_inpainting.do_inpainting_hijack()
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index c1977b19..2d44b856 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def do_inpainting_hijack():
- # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1608b37f..f6051604 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
-from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
def __init__(self):
self.sd_model = None
+ self.loaded_sd_models = []
self.was_loaded_at_least_once = False
self.lock = threading.Lock()
@@ -448,6 +448,7 @@ class SdModelData:
try:
load_model()
+
except Exception as e:
errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
@@ -459,11 +460,24 @@ class SdModelData:
def set_sd_model(self, v):
self.sd_model = v
+ try:
+ self.loaded_sd_models.remove(v)
+ except ValueError:
+ pass
+
+ if v is not None:
+ self.loaded_sd_models.insert(0, v)
+
model_data = SdModelData()
def get_empty_cond(sd_model):
+ from modules import extra_networks, processing
+
+ p = processing.StableDiffusionProcessingTxt2Img()
+ extra_networks.activate(p, {})
+
if hasattr(sd_model, 'conditioner'):
d = sd_model.get_learned_conditioning([""])
return d['crossattn']
@@ -471,19 +485,43 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
+def send_model_to_cpu(m):
+ from modules import lowvram
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ m.to(devices.cpu)
+
+ devices.torch_gc()
+
+
+def send_model_to_device(m):
+ from modules import lowvram
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+ else:
+ m.to(shared.device)
+
+
+def send_model_to_trash(m):
+ m.to(device="meta")
+ devices.torch_gc()
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
- from modules import lowvram, sd_hijack
+ from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ timer = Timer()
+
if model_data.sd_model:
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ send_model_to_trash(model_data.sd_model)
model_data.sd_model = None
- gc.collect()
devices.torch_gc()
- do_inpainting_hijack()
-
- timer = Timer()
+ timer.record("unload existing model")
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
@@ -523,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ timer.record("load weights from state dict")
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
- else:
- sd_model.to(shared.device)
-
+ send_model_to_device(sd_model)
timer.record("move model to device")
sd_hijack.model_hijack.hijack(sd_model)
@@ -536,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("hijack")
sd_model.eval()
- model_data.sd_model = sd_model
+ model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -557,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
return sd_model
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+ """
+ Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+ If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
+ If not, returns the model that can be used to load weights from checkpoint_info's file.
+ If no such model exists, returns None.
+ Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+ """
+
+ already_loaded = None
+ for i in reversed(range(len(model_data.loaded_sd_models))):
+ loaded_model = model_data.loaded_sd_models[i]
+ if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ already_loaded = loaded_model
+ continue
+
+ if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+ print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+ model_data.loaded_sd_models.pop()
+ send_model_to_trash(loaded_model)
+ timer.record("send model to trash")
+
+ if shared.opts.sd_checkpoints_keep_in_cpu:
+ send_model_to_cpu(sd_model)
+ timer.record("send model to cpu")
+
+ if already_loaded is not None:
+ send_model_to_device(already_loaded)
+ timer.record("send model to device")
+
+ model_data.set_sd_model(already_loaded)
+ print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ return model_data.sd_model
+ elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+ print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+ model_data.sd_model = None
+ load_model(checkpoint_info)
+ return model_data.sd_model
+ elif len(model_data.loaded_sd_models) > 0:
+ sd_model = model_data.loaded_sd_models.pop()
+ model_data.sd_model = sd_model
+
+ print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+ return sd_model
+ else:
+ return None
+
+
def reload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ timer = Timer()
+
if not sd_model:
sd_model = model_data.sd_model
@@ -569,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
- return
-
- sd_unet.apply_unet("None")
+ return sd_model
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- else:
- sd_model.to(devices.cpu)
+ sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+ if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ return sd_model
+ if sd_model is not None:
+ sd_unet.apply_unet("None")
+ send_model_to_cpu(sd_model)
sd_hijack.model_hijack.undo_hijack(sd_model)
- timer = Timer()
-
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -590,9 +674,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
if sd_model is not None:
- sd_model.to(device="meta")
+ send_model_to_trash(sd_model)
- devices.torch_gc()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
@@ -615,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
+ model_data.set_sd_model(sd_model)
+
return sd_model
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index bc219508..01123321 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -98,10 +98,10 @@ def extend_sdxl(model):
model.conditioner.wrapped = torch.nn.Module()
-sgm.modules.attention.print = lambda *args: None
-sgm.modules.diffusionmodules.model.print = lambda *args: None
-sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
-sgm.modules.encoders.modules.print = lambda *args: None
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
diff --git a/modules/shared.py b/modules/shared.py
index 86ffb48f..8245250a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -400,6 +400,7 @@ options_templates.update(options_section(('system', "System"), {
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+ "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -419,7 +420,9 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
@@ -906,3 +909,10 @@ def walk_files(path, allowed_extensions=None):
continue
yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+ if opts.hide_ldm_prints:
+ return
+
+ print(*args, **kwargs)