From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:49:05 +0800 Subject: Add CPU fp8 support Since norm layer need fp32, I only convert the linear operation layer(conv2d/linear) And TE have some pytorch function not support bf16 amp in CPU. I add a condition to indicate if the autocast is for unet. --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563..0cd2b55d 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -71,6 +71,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -93,10 +94,13 @@ def cond_cast_float(input): nv_rng = None -def autocast(disable=False): +def autocast(disable=False, unet=False): if disable: return contextlib.nullcontext() + if unet and fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.1 From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Oct 2023 15:24:26 +0800 Subject: ManualCast for 10/16 series gpu --- modules/devices.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 6 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d..c05f2b35 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None - - -def autocast(disable=False, unet=False): +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.1 From ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Sat, 28 Oct 2023 16:52:35 +0800 Subject: Add MPS manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index c05f2b35..d7c905c2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -121,6 +121,8 @@ def manual_autocast(): def manual_cast_forward(self, *args, **kwargs): org_dtype = next(self.parameters()).dtype self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} result = self.org_forward(*args, **kwargs) self.to(org_dtype) return result @@ -136,7 +138,6 @@ def manual_autocast(): def autocast(disable=False): - print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() @@ -146,6 +147,9 @@ def autocast(disable=False): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): return manual_autocast() + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.1 From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: Use options instead of cmd_args --- modules/devices.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index d7c905c2..03e7bdb7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ patch_module_list = [ torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward -- cgit v1.2.1 From 043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:31 +0800 Subject: Better naming --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 03e7bdb7..c19a7f40 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs): @contextlib.contextmanager -def manual_autocast(): +def manual_cast(): for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward @@ -148,10 +148,10 @@ def autocast(disable=False): return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_autocast() + return manual_cast() if has_mps() and shared.cmd_opts.precision != "full": - return manual_autocast() + return manual_cast() if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() -- cgit v1.2.1