aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-10-25 11:36:43 +0800
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-10-25 11:36:43 +0800
commit1df6c8bfec4715610d64684b6ad2fa38c76c1df6 (patch)
tree40ad7dc479ef592433981339f43968bcb48e7d2b
parent9c1eba2af3a6f9cd6282b3a367656793cbe70c01 (diff)
fp8 for TE
-rw-r--r--modules/sd_models.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 44d4038b..69395294 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")