From 1b8af15f13cba2bfce249d9837660ea4f28d451e Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 1 Feb 2023 09:28:16 -0500 Subject: Refactor Mac specific code to a separate file Move most Mac related code to a separate file, don't even load it unless web UI is run under macOS. --- modules/mac_specific.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 modules/mac_specific.py (limited to 'modules/mac_specific.py') diff --git a/modules/mac_specific.py b/modules/mac_specific.py new file mode 100644 index 00000000..e39d670e --- /dev/null +++ b/modules/mac_specific.py @@ -0,0 +1,56 @@ +import torch +from modules import paths +from modules.sd_hijack_utils import CondFunc +from packaging import version + + +device = None + + +# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. +# check `getattr` and try it for compatibility +def check_for_mps() -> bool: + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False +has_mps = check_for_mps() + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if output_dtype == torch.int64: + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) + return cumsum_func(input, *args, **kwargs) + + +if has_mps: + # MPS fix for randn in torchsde + CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps') + + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 + CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), + lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), + lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') + # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 + CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) + elif version.parse(torch.__version__) > version.parse("1.13.1"): + cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) + cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) + cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) + CondFunc('torch.cumsum', cumsum_fix_func, None) + CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) + CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) + -- cgit v1.2.1 From 4306659c4dab1a2ae611ac2a7487b87e1c513adf Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 4 Feb 2023 01:22:06 -0500 Subject: Remove unused code --- modules/mac_specific.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/mac_specific.py') diff --git a/modules/mac_specific.py b/modules/mac_specific.py index e39d670e..ddcea53b 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -4,9 +4,6 @@ from modules.sd_hijack_utils import CondFunc from packaging import version -device = None - - # has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # check `getattr` and try it for compatibility def check_for_mps() -> bool: -- cgit v1.2.1