From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:56:17 +0800 Subject: Add fp8 for sd unet --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea1..3b8ff820 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,6 +391,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + if shared.cmd_opts.opt_unet_fp8_storage: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.1 From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:57:22 +0800 Subject: Add sdxl only arg --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b8ff820..08af128f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.opt_unet_fp8_storage: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.1 From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:49:05 +0800 Subject: Add CPU fp8 support Since norm layer need fp32, I only convert the linear operation layer(conv2d/linear) And TE have some pytorch function not support bf16 amp in CPU. I add a condition to indicate if the autocast is for unet. --- modules/sd_models.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 08af128f..c5fe57bf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,12 +391,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + + if shared.cmd_opts.opt_unet_fp8_storage: + enable_fp8 = True + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + enable_fp8 = True + + if enable_fp8: + devices.fp8 = True + if devices.device == devices.cpu: + for module in model.model.diffusion_model.modules(): + if isinstance(module, torch.nn.Conv2d): + module.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for cpu") + else: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.1 From 9c1eba2af3a6f9cd6282b3a367656793cbe70c01 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 02:11:27 +0800 Subject: Fix lint --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index c5fe57bf..44d4038b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,7 +396,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True - + if enable_fp8: devices.fp8 = True if devices.device == devices.cpu: -- cgit v1.2.1 From 1df6c8bfec4715610d64684b6ad2fa38c76c1df6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:36:43 +0800 Subject: fp8 for TE --- modules/sd_models.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 44d4038b..69395294 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) timer.record("apply fp8 unet for cpu") else: + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") -- cgit v1.2.1 From 4830b251366436ee8499c003fe87e46ddb4a4581 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:53:37 +0800 Subject: Fix alphas_cumprod dtype --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 69395294..23660454 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -416,6 +416,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.1 From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:54:28 +0800 Subject: Fix alphas cumprod --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 23660454..7ed89a9c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,6 +396,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True + else: + enable_fp8 = False if enable_fp8: devices.fp8 = True @@ -416,7 +418,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.1 From dda067f64d3289cee3ffd65767126cb30ae73b13 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:53:22 +0800 Subject: ignore mps for fp8 --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7ed89a9c..ccb6afd2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -392,7 +392,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.cmd_opts.opt_unet_fp8_storage: enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True -- cgit v1.2.1 From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Oct 2023 15:24:26 +0800 Subject: ManualCast for 10/16 series gpu --- modules/sd_models.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2..31bcb913 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8: devices.fp8 = True + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + if devices.device == devices.cpu: for module in model.model.diffusion_model.modules(): if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) elif isinstance(module, torch.nn.Linear): module.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for cpu") else: - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet") + timer.record("apply fp8") + else: + devices.fp8 = False devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.1 From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: Use options instead of cmd_args --- modules/sd_models.py | 61 +++++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 29 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a6c8b2fa..eb491434 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -339,10 +339,28 @@ class SkipWritingToConfig: SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if not check_fp8(model) and devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if devices.get_optimal_device_name() == "mps": - enable_fp8 = False - elif shared.cmd_opts.opt_unet_fp8_storage: - enable_fp8 = True - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - enable_fp8 = True - else: - enable_fp8 = False - - if enable_fp8: + if check_fp8(model): devices.fp8 = True - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) - - if devices.device == devices.cpu: - for module in model.model.diffusion_model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) - else: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage timer.record("apply fp8") else: devices.fp8 = False @@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: -- cgit v1.2.1 From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: Option for using fp16 weight when apply lora --- modules/sd_models.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index eb491434..0a7777f1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + if check_fp8(model): devices.fp8 = True first_stage = model.first_stage_model model.first_stage_model = None for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + if module.bias is not None: + module.fp16_bias = module.bias.clone().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") -- cgit v1.2.1 From 40ac134c553ac824d4a96666bba14d550300daa5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 25 Nov 2023 12:35:09 +0800 Subject: Fix pre-fp8 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a7777f1..90437c87 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if not check_fp8(model) and devices.fp8: + if devices.fp8: # prevent model to load state dict in fp8 model.half() -- cgit v1.2.1 From 50a21cb09fe3e9ea2d4fe058e0484e192c8a86e3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 2 Dec 2023 22:06:47 +0800 Subject: Ensure the cached weight will not be affected --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b8a9ae6..dcf816b3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: - module.fp16_weight = module.weight.clone().half() + module.fp16_weight = module.weight.data.clone().cpu().half() if module.bias is not None: - module.fp16_bias = module.bias.clone().half() + module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") -- cgit v1.2.1 From 672dc4efa8e0da38426b121e7c7216d0a8e465fd Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:16:10 +0800 Subject: Fix forced reload --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index dcf816b3..d0046f88 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -801,7 +801,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if check_fp8(sd_model) != devices.fp8: # load from state dict again to prevent extra numerical errors forced_reload = True - elif sd_model.sd_model_checkpoint == checkpoint_info.filename: + elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) -- cgit v1.2.1