aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-02-01 09:28:16 -0500
committerbrkirch <brkirch@users.noreply.github.com>2023-02-01 14:05:56 -0500
commit1b8af15f13cba2bfce249d9837660ea4f28d451e (patch)
treed961d858e70941197701982cc92e6c18b3e76816 /modules/devices.py
parent226d840e84c5f306350b0681945989b86760e616 (diff)
Refactor Mac specific code to a separate file
Move most Mac related code to a separate file, don't even load it unless web UI is run under macOS.
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py52
1 files changed, 7 insertions, 45 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 919048d0..52c3e7cd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,22 +1,17 @@
-import sys, os, shlex
+import sys
import contextlib
import torch
from modules import errors
-from modules.sd_hijack_utils import CondFunc
-from packaging import version
+
+if sys.platform == "darwin":
+ from modules import mac_specific
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
def has_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
+ if sys.platform != "darwin":
return False
-
+ else:
+ return mac_specific.has_mps
def extract_device_id(args, name):
for x in range(len(args)):
@@ -155,36 +150,3 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
- if input.device.type == 'mps':
- output_dtype = kwargs.get('dtype', input.dtype)
- if output_dtype == torch.int64:
- return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
- return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
- return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
- if version.parse(torch.__version__) < version.parse("1.13"):
- # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-
- # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
- CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
- lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
- # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
- CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
- lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
- # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
- CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
- elif version.parse(torch.__version__) > version.parse("1.13.1"):
- cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
- cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
- CondFunc('torch.cumsum', cumsum_fix_func, None)
- CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
- CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-