author     AUTOMATIC1111 <16777216c@gmail.com>   2023-01-04 19:56:35 +0300
committer  GitHub <noreply@github.com>           2023-01-04 19:56:35 +0300
commit     eeb1de4388773ba92b9920a4f64eb91add2e02ca (patch)
tree       22f5d5e7417f24599a415fd64c9f1652495ce5a3 /modules/devices.py
parent     d85c2cb2d59f64cbb510a9e5596596de2e4f4dcc (diff)
parent     b7deea47eeb033052062621b0005d4321b53bff7 (diff)
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules/devices.py')
-rw-r--r--  modules/devices.py | 115
1 file changed, 84 insertions(+), 31 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index 7511e1dc..800510b7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -2,72 +2,95 @@ import sys, os, shlex
 import contextlib
 import torch
 from modules import errors
+from packaging import version
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def has_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
+
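Why the runtime probe: `torch.has_mps` can report True on builds where the Metal backend still fails once used, so the function confirms availability by actually allocating a tensor. A minimal standalone sketch of the same probe pattern, using only stock torch (the helper name is illustrative):

    import torch

    def backend_usable(name: str) -> bool:
        # A tiny allocation is the real test; any failure means the backend is unusable.
        try:
            torch.zeros(1).to(torch.device(name))
            return True
        except Exception:
            return False

    print(backend_usable("mps"))  # False on non-Apple hardware or older torch builds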
 def extract_device_id(args, name):
     for x in range(len(args)):
-        if name in args[x]: return args[x+1]
+        if name in args[x]:
+            return args[x + 1]
+
     return None
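Usage sketch for extract_device_id: it scans an argv-style list for a flag and returns the token after it (the argument values below are illustrative, not from the source):

    args = ["--ckpt", "model.safetensors", "--device-id", "1"]
    print(extract_device_id(args, "--device-id"))  # "1"
    print(extract_device_id(args, "--unknown"))    # None

Note that `name in args[x]` is a substring match and `args[x + 1]` is unchecked, so a matching flag given as the last token would raise IndexError; callers are expected to pass well-formed argument lists.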
-def get_optimal_device():
-    if torch.cuda.is_available():
-        from modules import shared
-        device_id = shared.cmd_opts.device_id
+def get_cuda_device_string():
+    from modules import shared
+
+    if shared.cmd_opts.device_id is not None:
+        return f"cuda:{shared.cmd_opts.device_id}"
 
-        if device_id is not None:
-            cuda_device = f"cuda:{device_id}"
-            return torch.device(cuda_device)
-        else:
-            return torch.device("cuda")
+    return "cuda"
 
-    if has_mps:
+
+def get_optimal_device():
+    if torch.cuda.is_available():
+        return torch.device(get_cuda_device_string())
+
+    if has_mps():
         return torch.device("mps")
 
     return cpu
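The selection order after this change is CUDA first (honoring --device-id via get_cuda_device_string), then MPS, then CPU. A condensed sketch of that decision, with the cmd_opts plumbing omitted:

    import torch

    def pick_device() -> torch.device:
        # Illustrative only; the real code builds "cuda:N" from --device-id.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if getattr(torch, 'has_mps', False):
            return torch.device("mps")
        return torch.device("cpu")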
+def get_device_for(task):
+    from modules import shared
+
+    if task in shared.cmd_opts.use_cpu:
+        return cpu
+
+    return get_optimal_device()
+
+
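get_device_for lets individual tasks be pinned to the CPU via the --use-cpu flag while everything else stays on the optimal device. A hypothetical call site, assuming `--use-cpu interrogate` was passed on the command line:

    # shared.cmd_opts.use_cpu would then contain ["interrogate"]
    device_interrogate = get_device_for("interrogate")  # -> cpu
    device_sd = get_device_for("sd")                    # -> cuda/mps/cpu as available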
 def torch_gc():
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+        with torch.cuda.device(get_cuda_device_string()):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
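The scoping matters because torch.cuda.empty_cache() and ipc_collect() act on the current CUDA device, so the `with torch.cuda.device(...)` block makes the flush hit the card selected by --device-id rather than device 0. A typical (hypothetical) call site after dropping a model:

    import gc

    del model      # hypothetical reference holding GPU memory
    gc.collect()   # release Python-side references first
    torch_gc()     # then return cached CUDA blocks to the driver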
 def enable_tf32():
     if torch.cuda.is_available():
+
+        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+            torch.backends.cudnn.benchmark = True
+
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
 
+
 errors.run(enable_tf32, "Enabling TF32")
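Compute capability (7, 5) identifies Turing cards (GTX 16xx / RTX 20xx), the range the linked PR found to benefit from cudnn.benchmark. The TF32 switches are the standard PyTorch flags; a stock-torch sketch for inspecting both:

    import torch

    if torch.cuda.is_available():
        for idx in range(torch.cuda.device_count()):
            print(idx, torch.cuda.get_device_capability(idx))   # e.g. (7, 5) on Turing
        print(torch.backends.cuda.matmul.allow_tf32)            # True once enable_tf32() has run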
-device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
-def randn(seed, shape):
-    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
-    if device.type == 'mps':
-        generator = torch.Generator(device=cpu)
-        generator.manual_seed(seed)
-        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
-        return noise
+def randn(seed, shape):
     torch.manual_seed(seed)
+    if device.type == 'mps':
+        return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
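The MPS branch survives in simplified form because seeded sampling on the Metal backend was not reproducible at the time; drawing the noise on the CPU (where torch.manual_seed behaves deterministically) and copying it over keeps seeds portable. The pattern in isolation, with an illustrative shape:

    import torch

    torch.manual_seed(1234)
    noise = torch.randn((4, 64, 64), device=torch.device("cpu"))
    # noise = noise.to(torch.device("mps"))  # same values as any other backend would get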
 def randn_without_seed(shape):
-    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
     if device.type == 'mps':
-        generator = torch.Generator(device=cpu)
-        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
-        return noise
-
+        return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
@@ -82,6 +105,36 @@ def autocast(disable=False):
     return torch.autocast("cuda")
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+orig_tensor_to = torch.Tensor.to
+def tensor_to_fix(self, *args, **kwargs):
+    if self.device.type != 'mps' and \
+       ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
+       (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
+        self = self.contiguous()
+    return orig_tensor_to(self, *args, **kwargs)
+
+
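For context on what the .contiguous() call repairs: strided views such as transposes share storage without a linear layout, which is what tripped CPU-to-MPS copies in pytorch#79383. A quick stock-torch illustration of such a tensor:

    import torch

    t = torch.arange(6).reshape(2, 3).t()  # transpose -> non-contiguous view
    print(t.is_contiguous())               # False; this is the case the wrapper normalizes
    # t.to(torch.device("mps"))            # would go through tensor_to_fix on a patched build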
+# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+orig_layer_norm = torch.nn.functional.layer_norm
+def layer_norm_fix(*args, **kwargs):
+    if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
+        args = list(args)
+        args[0] = args[0].contiguous()
+    return orig_layer_norm(*args, **kwargs)
+
+
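layer_norm hits the same contiguity problem (pytorch#80800); the wrapper rewrites only the first positional argument, the input tensor, before delegating. A hypothetical call it would intercept on MPS:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)        # on a patched MPS build this is made contiguous first
    out = F.layer_norm(x, (8,))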
+# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+orig_tensor_numpy = torch.Tensor.numpy
+def numpy_fix(self, *args, **kwargs):
+    if self.requires_grad:
+        self = self.detach()
+    return orig_tensor_numpy(self, *args, **kwargs)
+
+
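The .numpy() patch covers a stock PyTorch rule: tensors tracked by autograd refuse direct conversion, and detaching first is the documented escape hatch (pytorch#90532 is where this surfaced for MPS users). Illustration with plain torch:

    import torch

    t = torch.ones(3, requires_grad=True)
    # t.numpy()  # RuntimeError: use tensor.detach().numpy() instead
    arr = t.detach().numpy()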
+# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
+    torch.Tensor.to = tensor_to_fix
+    torch.nn.functional.layer_norm = layer_norm_fix
+    torch.Tensor.numpy = numpy_fix
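The guard applies all three patches only when MPS is usable and torch is older than 1.13, using packaging.version for a real version comparison; naive string comparison gets multi-digit components wrong. A quick check of the difference:

    from packaging import version

    print(version.parse("1.12.1") < version.parse("1.13"))  # True
    print("1.9" < "1.13")                                   # False: string compare misorders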