aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorExtraltodeus <extraltodeus@gmail.com>2022-10-22 00:11:07 +0200
committerGitHub <noreply@github.com>2022-10-22 00:11:07 +0200
commit57eb54b838faa383c10079e1bb5471b7bee6a695 (patch)
treeeb18f6c912448f17d12fd9f13c85b553ef5387ad /modules/devices.py
parentf49c08ea566385db339c6628f65c3a121033f67c (diff)
implement CUDA device selection by ID
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..8a159282 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]: return args[x+1]
+ return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ # CUDA device selection support:
+ if "shared" not in sys.modules:
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
+ sys.argv += shlex.split(commandline_args)
+ device_id = extract_device_id(sys.argv, '--device-id')
+ else:
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")