about summary refs log tree commit diff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-10-22 13:58:00 +0300
committerGitHub <noreply@github.com>2022-10-22 13:58:00 +0300
commite80bdcab91df0d91fa268991bee1d0143e81920a (patch)
tree347f8cbcdf644885fcf3481ed7a2dc55f8942c6e /modules/devices.py
parent5aa9525046b7520d39fe8fc8c5c6cc10ab4d5fdb (diff)
parent1fa53dab2c5a857b9773f904fadf853dac1f1bd6 (diff)
Merge pull request #3377 from Extraltodeus/cuda-device-id-selection
Implementation of CUDA device id selection (--device-id 0/1/2)
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..8a159282 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
def extract_device_id(args, name):
    """Return the token that follows the first argument containing *name*.

    A minimal argv scanner used to pull ``--device-id <n>`` out of the raw
    command line without importing the shared options module (which would
    create an import loop). Matching is deliberately by substring, so callers
    must pass a distinctive flag such as ``'--device-id'``.

    Args:
        args: argv-style list of string tokens (e.g. ``sys.argv``).
        name: flag to look for (substring match against each token).

    Returns:
        The token immediately after the matching one, or ``None`` when the
        flag is absent — or when it is the final token and has no value after
        it (the original indexed ``args[x+1]`` unconditionally and raised
        IndexError in that case).
    """
    for i, arg in enumerate(args):
        if name in arg:
            # Guard the boundary: flag given with no value after it.
            return args[i + 1] if i + 1 < len(args) else None
    return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ # CUDA device selection support:
+ if "shared" not in sys.modules:
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
+ sys.argv += shlex.split(commandline_args)
+ device_id = extract_device_id(sys.argv, '--device-id')
+ else:
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")