aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/devices.py30
1 files changed, 11 insertions, 19 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 93d82bbc..dd50fe24 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -24,17 +24,18 @@ def extract_device_id(args, name):
return None
-def get_optimal_device():
- if torch.cuda.is_available():
- from modules import shared
+def get_cuda_device_string():
+ from modules import shared
+
+ if shared.cmd_opts.device_id is not None:
+ return f"cuda:{shared.cmd_opts.device_id}"
- device_id = shared.cmd_opts.device_id
+ return "cuda"
- if device_id is not None:
- cuda_device = f"cuda:{device_id}"
- return torch.device(cuda_device)
- else:
- return torch.device("cuda")
+
+def get_optimal_device():
+ if torch.cuda.is_available():
+ return torch.device(get_cuda_device_string())
if has_mps():
return torch.device("mps")
@@ -44,16 +45,7 @@ def get_optimal_device():
def torch_gc():
if torch.cuda.is_available():
- from modules import shared
-
- device_id = shared.cmd_opts.device_id
-
- if device_id is not None:
- cuda_device = f"cuda:{device_id}"
- else:
- cuda_device = "cuda"
-
- with torch.cuda.device(cuda_device):
+ with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()