aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py32
1 files changed, 25 insertions, 7 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 7511e1dc..67165bf6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,16 +3,27 @@ import contextlib
import torch
from modules import errors
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
+# check `getattr` and try it for compatibility
+def has_mps() -> bool:
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+
def extract_device_id(args, name):
for x in range(len(args)):
- if name in args[x]: return args[x+1]
+ if name in args[x]:
+ return args[x + 1]
+
return None
+
def get_optimal_device():
if torch.cuda.is_available():
from modules import shared
@@ -25,7 +36,7 @@ def get_optimal_device():
else:
return torch.device("cuda")
- if has_mps:
+ if has_mps():
return torch.device("mps")
return cpu
@@ -45,10 +56,12 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
+cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
+
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
@@ -82,6 +95,11 @@ def autocast(disable=False):
return torch.autocast("cuda")
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+def mps_contiguous(input_tensor, device):
+ return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+
+
+def mps_contiguous_to(input_tensor, device):
+ return mps_contiguous(input_tensor, device).to(device)