aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py24
1 files changed, 13 insertions, 11 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 6e8277e5..f00079c6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -25,17 +25,18 @@ def extract_device_id(args, name):
return None
-def get_optimal_device():
- if torch.cuda.is_available():
- from modules import shared
+def get_cuda_device_string():
+ from modules import shared
+
+ if shared.cmd_opts.device_id is not None:
+ return f"cuda:{shared.cmd_opts.device_id}"
- device_id = shared.cmd_opts.device_id
+ return "cuda"
- if device_id is not None:
- cuda_device = f"cuda:{device_id}"
- return torch.device(cuda_device)
- else:
- return torch.device("cuda")
+
+def get_optimal_device():
+ if torch.cuda.is_available():
+ return torch.device(get_cuda_device_string())
if has_mps():
return torch.device("mps")
@@ -45,8 +46,9 @@ def get_optimal_device():
def torch_gc():
if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+ with torch.cuda.device(get_cuda_device_string()):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
def enable_tf32():