aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_models.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 44d4038b..69395294 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")