Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--  modules/sd_models.py  106
1 file changed, 100 insertions(+), 6 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9355f1e1..747fc39e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,6 +15,7 @@ from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
+from modules.shared import opts
import tomesd
import numpy as np
@@ -348,10 +349,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
+def check_fp8(model):
+ if model is None:
+ return None
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.opts.fp8_storage == "Enable":
+ enable_fp8 = True
+ elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+ return enable_fp8
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
+ if devices.fp8:
+ # cast to fp16 first so the state dict is not loaded into fp8 parameters
+ model.half()
+
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
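A standalone sketch of the precedence that check_fp8 implements, with the webui objects replaced by plain arguments (fp8_storage stands in for shared.opts.fp8_storage, device for devices.get_optimal_device_name(); fp8_enabled is an illustrative helper, not webui code): MPS always disables FP8, "Enable" turns it on for every model, and "Enable for SDXL" only applies when the model reports is_sdxl.

# Standalone sketch of the check_fp8 precedence; the arguments are stand-ins
# for shared.opts.fp8_storage and devices.get_optimal_device_name().
def fp8_enabled(model_is_sdxl: bool, fp8_storage: str, device: str) -> bool:
    if device == "mps":              # FP8 dtypes are not supported on Apple MPS
        return False
    if fp8_storage == "Enable":      # global opt-in covers every model
        return True
    if fp8_storage == "Enable for SDXL":
        return model_is_sdxl         # opt-in restricted to SDXL checkpoints
    return False                     # "Disable" or any other value

assert fp8_enabled(False, "Enable", "cuda") is True
assert fp8_enabled(True, "Enable for SDXL", "cuda") is True
assert fp8_enabled(False, "Enable for SDXL", "cuda") is False
assert fp8_enabled(True, "Enable", "mps") is False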
@@ -383,6 +402,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
+ model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
@@ -396,7 +416,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
+ alphas_cumprod = model.alphas_cumprod
+ model.alphas_cumprod = None
model.half()
+ model.alphas_cumprod = alphas_cumprod
+ model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
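The block above works around model.half() downcasting every registered buffer, including alphas_cumprod, by detaching the buffer before the cast and restoring it afterwards. A minimal sketch of the same pattern on a toy module (ToyModel is a hypothetical stand-in, not the LDM wrapper):

# Minimal sketch of the "detach a buffer around .half()" trick used above.
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.register_buffer("alphas_cumprod", torch.linspace(1.0, 0.01, 10))

model = ToyModel()

alphas = model.alphas_cumprod          # keep a float32 reference
model.alphas_cumprod = None            # hide the buffer from .half()
model.half()                           # downcasts parameters and remaining buffers
model.alphas_cumprod = alphas          # restore the full-precision schedule

assert model.linear.weight.dtype == torch.float16
assert model.alphas_cumprod.dtype == torch.float32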
@@ -404,6 +428,30 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ apply_alpha_schedule_override(model)
+
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
+ if check_fp8(model):
+ devices.fp8 = True
+ first_stage = model.first_stage_model
+ model.first_stage_model = None
+ for module in model.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.data.clone().cpu().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.data.clone().cpu().half()
+ module.to(torch.float8_e4m3fn)
+ model.first_stage_model = first_stage
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
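A hedged sketch of what the FP8 pass does to each Linear and Conv2d layer: optionally stash an fp16 copy of the weights on the CPU (so full precision can be recovered later by code that needs it), then cast the module's storage to torch.float8_e4m3fn. It assumes a PyTorch build with float8 dtypes (2.1 or newer); the toy model below is not the real UNet, and in the real code the VAE is excluded by temporarily detaching first_stage_model.

# Hedged sketch of the FP8 storage pass; assumes torch.float8_e4m3fn is available.
import torch

def to_fp8_storage(root: torch.nn.Module, cache_fp16_weight: bool = True):
    # mirror of the loop above: only Linear and Conv2d layers are converted
    for module in root.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            if cache_fp16_weight:
                # keep an fp16 master copy on the CPU before the lossy fp8 cast
                module.fp16_weight = module.weight.data.clone().cpu().half()
                if module.bias is not None:
                    module.fp16_bias = module.bias.data.clone().cpu().half()
            module.to(torch.float8_e4m3fn)

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Conv2d(3, 3, 1))
to_fp8_storage(toy)
print(toy[0].weight.dtype)       # torch.float8_e4m3fn
print(toy[0].fp16_weight.dtype)  # torch.float16, kept on the CPU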
@@ -505,6 +553,48 @@ def repair_config(sd_config):
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
+def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Square to revert the sqrt and recover alphas_bar
+ alphas_bar = alphas_bar_sqrt ** 2
+ alphas_bar[-1] = 4.8973451890853435e-08  # keep the final value slightly above zero instead of exactly zero
+ return alphas_bar
+
+
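A quick numerical check of rescale_zero_terminal_snr_abar on an arbitrary toy linear beta schedule (not the schedule the webui actually uses): the first alpha-bar value is preserved exactly, while the last is pushed to (almost) zero; since SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]), that is precisely "zero terminal SNR".

# Toy check of the rescale; the beta schedule here is arbitrary, for illustration only.
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

rescaled = rescale_zero_terminal_snr_abar(alphas_cumprod)

print(alphas_cumprod[0].item(), rescaled[0].item())    # first step unchanged
print(alphas_cumprod[-1].item(), rescaled[-1].item())  # last step pushed to ~0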
+def apply_alpha_schedule_override(sd_model, p=None):
+ """
+ Applies an override to the alpha schedule of the model according to settings.
+ - downcasts the alpha schedule to half precision
+ - rescales the alpha schedule to have zero terminal SNR
+ """
+
+ if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
+ return
+
+ sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ if p is not None:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
+
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ if p is not None:
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
+
+
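The reason the function always restarts from alphas_cumprod_original (saved in load_model_weights above) is that it can be re-applied on every generation without drifting: the downcast and the rescale are redone from the pristine schedule each time rather than compounding. A small sketch, assuming the normal webui runtime is loaded (shared.device and the two options exist); the dummy object is just a namespace with the two attributes the function reads:

import types, torch

dummy = types.SimpleNamespace(
    alphas_cumprod=None,
    alphas_cumprod_original=torch.linspace(0.999, 0.001, 1000),
)

apply_alpha_schedule_override(dummy)
first = dummy.alphas_cumprod.clone()
apply_alpha_schedule_override(dummy)   # recomputed from the original: no cumulative drift
assert torch.equal(first, dummy.alphas_cumprod)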
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
@@ -651,6 +741,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
+ 'alphas_cumprod': None,
'': torch.float16,
}
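The new 'alphas_cumprod': None entry keeps that buffer out of the fp16 conversion applied to the rest of the state dict during loading. Assuming the map is consumed as a lookup on the first segment of each state-dict key with the '' entry as the fallback (an assumption for illustration, matching how the existing 'first_stage_model' and '' entries read, not a quote of the webui loader), the effect is:

# Hypothetical helper illustrating how a prefix -> dtype map like the one above
# can be interpreted; this is an assumption, not the webui's actual loader code.
import torch

weight_dtype_conversion = {
    'first_stage_model': None,      # keep the VAE weights in their original dtype
    'alphas_cumprod': None,         # keep the noise schedule in full precision
    '': torch.float16,              # everything else is stored as fp16
}

def dtype_for(key: str):
    prefix = key.split('.', 1)[0]
    return weight_dtype_conversion.get(prefix, weight_dtype_conversion[''])

print(dtype_for('model.diffusion_model.input_blocks.0.0.weight'))  # torch.float16
print(dtype_for('first_stage_model.encoder.conv_in.weight'))       # None -> unchanged
print(dtype_for('alphas_cumprod'))                                  # None -> unchanged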
@@ -746,7 +837,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -758,11 +849,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+ if check_fp8(sd_model) != devices.fp8:
+ # the fp8 setting changed; reload from the state dict to avoid compounding numerical error
+ forced_reload = True
+ elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
- if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
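A standalone sketch of the early-return logic above (needs_reload is an illustrative helper, not webui code): a reload is skipped only when the checkpoint is unchanged, no caller forced it, and the model's FP8 state still matches the current setting; an FP8 mismatch always forces a reload because the weights already on the model have been through a lossy cast.

def needs_reload(same_checkpoint: bool, forced_reload: bool, fp8_state_matches: bool) -> bool:
    if not fp8_state_matches:
        # weights were already cast to/from fp8; reload from the state dict
        return True
    return forced_reload or not same_checkpoint

assert needs_reload(True, False, True) is False   # nothing changed, keep the model
assert needs_reload(True, False, False) is True   # fp8 setting flipped, force a reload
assert needs_reload(True, True, True) is True     # caller explicitly asked for one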
@@ -793,13 +887,13 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("script callbacks")
-
if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
+ script_callbacks.model_loaded_callback(sd_model)
+ timer.record("script callbacks")
+
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)