aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 1325569c..f8cffae1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -44,6 +44,15 @@ def get_optimal_device():
return cpu
+def get_device_for(task):
+ from modules import shared
+
+ if task in shared.cmd_opts.use_cpu:
+ return cpu
+
+ return get_optimal_device()
+
+
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
@@ -53,12 +62,12 @@ def torch_gc():
def enable_tf32():
if torch.cuda.is_available():
- for devid in range(0,torch.cuda.device_count()):
- if torch.cuda.get_device_capability(devid) == (7, 5):
- shd = True
- if shd:
+
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+ if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
torch.backends.cudnn.benchmark = True
- torch.backends.cudnn.enabled = True
+
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@@ -67,7 +76,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16