aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py48
1 files changed, 38 insertions, 10 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6fbdcd6..547e93c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -343,7 +343,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
- if not shared.cmd_opts.no_half:
+ if shared.cmd_opts.no_half:
+ model.float()
+ devices.dtype_unet = torch.float32
+ timer.record("apply float()")
+ else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -359,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
+ devices.dtype_unet = torch.float16
timer.record("apply half()")
- devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -482,8 +486,12 @@ class SdModelData:
return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -510,7 +518,7 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
@@ -518,10 +526,17 @@ def send_model_to_cpu(m):
devices.torch_gc()
-def send_model_to_device(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+def model_target_device(m):
+ if lowvram.is_needed(m):
+ return devices.cpu
else:
+ return devices.device
+
+
+def send_model_to_device(m):
+ lowvram.apply(m)
+
+ if not m.lowvram:
m.to(shared.device)
@@ -579,7 +594,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
@@ -642,13 +665,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
+ model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -660,6 +684,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
@@ -716,7 +744,7 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")