diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/modules/devices.py b/modules/devices.py index 6e8277e5..f00079c6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -25,17 +25,18 @@ def extract_device_id(args, name): return None -def get_optimal_device(): - if torch.cuda.is_available(): - from modules import shared +def get_cuda_device_string(): + from modules import shared + + if shared.cmd_opts.device_id is not None: + return f"cuda:{shared.cmd_opts.device_id}" - device_id = shared.cmd_opts.device_id + return "cuda" - if device_id is not None: - cuda_device = f"cuda:{device_id}" - return torch.device(cuda_device) - else: - return torch.device("cuda") + +def get_optimal_device(): + if torch.cuda.is_available(): + return torch.device(get_cuda_device_string()) if has_mps(): return torch.device("mps") @@ -45,8 +46,9 @@ def get_optimal_device(): def torch_gc(): if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + with torch.cuda.device(get_cuda_device_string()): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def enable_tf32(): |