aboutsummaryrefslogtreecommitdiff
path: root/modules/mac_specific.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-03-11 17:35:17 -0500
committerbrkirch <brkirch@users.noreply.github.com>2023-03-11 17:35:17 -0500
commita4cb96d4ae82741be9f0d072a37af3ae39521379 (patch)
tree7415bc6123a43d77b5d62a4db43cb6d8ed2b7e72 /modules/mac_specific.py
parent27e319dc4f09a2f040043948e5c52965976f8491 (diff)
Remove test, use bool tensor fix by default
The test isn't working correctly on macOS 13.3 and the bool tensor fix for cumsum is currently always needed anyway, so enable the fix by default.
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--modules/mac_specific.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..18e6ff72 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -45,7 +45,6 @@ if has_mps:
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)