aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorunknown <mcgpapu@gmail.com>2023-02-03 20:39:42 -0600
committerunknown <mcgpapu@gmail.com>2023-02-03 20:39:42 -0600
commit5e1f4f7464e560a2468812fc9d5cb38659cff86c (patch)
tree6b3e6676384fae53f3359aeea9ac51d32a5affd6 /modules/devices.py
parentade40aa1a0605ba4aa3adc734ffb2b5121729d03 (diff)
parent226d840e84c5f306350b0681945989b86760e616 (diff)
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui into gamepad
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py49
1 files changed, 15 insertions, 34 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 655ca1d3..919048d0 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -2,6 +2,7 @@ import sys, os, shlex
import contextlib
import torch
from modules import errors
+from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -156,36 +157,7 @@ def test_for_nans(x, where):
raise NansException(message)
-# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-orig_tensor_to = torch.Tensor.to
-def tensor_to_fix(self, *args, **kwargs):
- if self.device.type != 'mps' and \
- ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
- (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
- self = self.contiguous()
- return orig_tensor_to(self, *args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-orig_layer_norm = torch.nn.functional.layer_norm
-def layer_norm_fix(*args, **kwargs):
- if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
- args = list(args)
- args[0] = args[0].contiguous()
- return orig_layer_norm(*args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-orig_tensor_numpy = torch.Tensor.numpy
-def numpy_fix(self, *args, **kwargs):
- if self.requires_grad:
- self = self.detach()
- return orig_tensor_numpy(self, *args, **kwargs)
-
-
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-orig_cumsum = torch.cumsum
-orig_Tensor_cumsum = torch.Tensor.cumsum
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
@@ -199,11 +171,20 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
if has_mps():
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
- torch.Tensor.numpy = numpy_fix
+
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
- torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
- torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+