diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/modules/devices.py b/modules/devices.py index 25008a04..30d30b99 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,12 +1,16 @@ import torch - # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility has_mps = getattr(torch, 'has_mps', False) +cpu = torch.device("cpu") + + def get_optimal_device(): - if torch.cuda.is_available(): - return torch.device("cuda") - if has_mps: - return torch.device("mps") - return torch.device("cpu") + if torch.cuda.is_available(): + return torch.device("cuda") + + if has_mps: + return torch.device("mps") + + return cpu |