aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorcaptin411 <captindave@gmail.com>2022-10-25 13:22:27 -0700
committercaptin411 <captindave@gmail.com>2022-10-25 13:22:27 -0700
commit6629446a2f9bb3ade1c271854aae1530ba1a8cc3 (patch)
treead7cfd2b3f0208c24da64c7f08e0550e783228ec /modules/devices.py
parent3e6c2420c1177e9e79f2b566a5a7795b7416e34a (diff)
parent3e15f8e0f5cc87507f77546d92435670644dbd18 (diff)
Merge branch 'master' into focal-point-cropping
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py23
1 files changed, 19 insertions, 4 deletions
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..7511e1dc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]: return args[x+1]
+ return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")
@@ -34,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
@@ -70,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)