aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
blob: 25008a04211a10f9305e6e2e6bb6367b04f1323b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
import torch


def _detect_mps() -> bool:
    """Return True when the Apple MPS backend can actually be used.

    ``torch.has_mps`` only reports whether this torch build was compiled with
    MPS support. It can be True on a machine where MPS is not usable, and it
    is deprecated in newer PyTorch releases. Prefer the runtime check
    ``torch.backends.mps.is_available()``. Fall back to the compile-time flag
    (read via ``getattr`` for compatibility with builds that lack it) only
    when that API does not exist.
    """
    try:
        from torch.backends import mps as mps_backend
    except ImportError:
        # Older builds without the torch.backends.mps submodule.
        mps_backend = None
    is_available = getattr(mps_backend, "is_available", None)
    if is_available is not None:
        return bool(is_available())
    return bool(getattr(torch, 'has_mps', False))


# True only if MPS is usable at runtime (not merely compiled in).
has_mps = _detect_mps()

def get_optimal_device():
    """Choose the preferred torch device for this process.

    Order of preference: CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise Apple MPS when the module-level ``has_mps`` flag is truthy,
    otherwise the CPU.

    Returns:
        torch.device: a ``cuda``, ``mps`` or ``cpu`` device.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif has_mps:
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)