From 7738c057ce938ca5c5a53a95e2023d3bcf14f06a Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 1 Feb 2023 05:23:58 -0500 Subject: MPS fix is still needed :( Apparently I did not test with large enough images to trigger the bug with torch.narrow on MPS --- modules/devices.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 655ca1d3..f4afb897 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -207,3 +207,6 @@ if has_mps(): cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) + -- cgit v1.2.1 From 2217331cd1245d0bdda786a5dcaf4f7b843bd7e4 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 1 Feb 2023 06:20:19 -0500 Subject: Refactor MPS fixes to CondFunc --- modules/devices.py | 50 ++++++++++++++------------------------------------ 1 file changed, 14 insertions(+), 36 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index f4afb897..919048d0 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,6 +2,7 @@ import sys, os, shlex import contextlib import torch from modules import errors +from modules.sd_hijack_utils import CondFunc from packaging import version @@ -156,36 +157,7 @@ def test_for_nans(x, where): raise NansException(message) -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -orig_tensor_to = torch.Tensor.to -def tensor_to_fix(self, *args, **kwargs): - if self.device.type != 'mps' and \ - ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ - (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): - self = self.contiguous() - return orig_tensor_to(self, *args, **kwargs) - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 -orig_layer_norm = torch.nn.functional.layer_norm -def layer_norm_fix(*args, **kwargs): - if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': - args = list(args) - args[0] = args[0].contiguous() - return orig_layer_norm(*args, **kwargs) - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 -orig_tensor_numpy = torch.Tensor.numpy -def numpy_fix(self, *args, **kwargs): - if self.requires_grad: - self = self.detach() - return orig_tensor_numpy(self, *args, **kwargs) - - # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 -orig_cumsum = torch.cumsum -orig_Tensor_cumsum = torch.Tensor.cumsum def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': output_dtype = kwargs.get('dtype', input.dtype) @@ -199,14 +171,20 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): if has_mps(): if version.parse(torch.__version__) < version.parse("1.13"): # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix + + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 + CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), + lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), + lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') + # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 + CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) elif version.parse(torch.__version__) > version.parse("1.13.1"): cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) - torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) - torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) - orig_narrow = torch.narrow - torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) + cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) + CondFunc('torch.cumsum', cumsum_fix_func, None) + CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) + CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) -- cgit v1.2.1 From 1b8af15f13cba2bfce249d9837660ea4f28d451e Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 1 Feb 2023 09:28:16 -0500 Subject: Refactor Mac specific code to a separate file Move most Mac related code to a separate file, don't even load it unless web UI is run under macOS. --- modules/devices.py | 52 +++++++--------------------------------------------- 1 file changed, 7 insertions(+), 45 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 919048d0..52c3e7cd 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,22 +1,17 @@ -import sys, os, shlex +import sys import contextlib import torch from modules import errors -from modules.sd_hijack_utils import CondFunc -from packaging import version + +if sys.platform == "darwin": + from modules import mac_specific -# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. -# check `getattr` and try it for compatibility def has_mps() -> bool: - if not getattr(torch, 'has_mps', False): - return False - try: - torch.zeros(1).to(torch.device("mps")) - return True - except Exception: + if sys.platform != "darwin": return False - + else: + return mac_specific.has_mps def extract_device_id(args, name): for x in range(len(args)): @@ -155,36 +150,3 @@ def test_for_nans(x, where): message += " Use --disable-nan-check commandline argument to disable this check." raise NansException(message) - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 -def cumsum_fix(input, cumsum_func, *args, **kwargs): - if input.device.type == 'mps': - output_dtype = kwargs.get('dtype', input.dtype) - if output_dtype == torch.int64: - return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) - elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): - return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) - return cumsum_func(input, *args, **kwargs) - - -if has_mps(): - if version.parse(torch.__version__) < version.parse("1.13"): - # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working - - # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 - CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), - lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) - # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 - CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), - lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') - # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 - CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) - elif version.parse(torch.__version__) > version.parse("1.13.1"): - cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) - cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) - cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) - CondFunc('torch.cumsum', cumsum_fix_func, None) - CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) - CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) - -- cgit v1.2.1