diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 21 |
1 file changed, 15 insertions, 6 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 1325569c..f8cffae1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -44,6 +44,15 @@ def get_optimal_device():
     return cpu


+def get_device_for(task):
+    from modules import shared
+
+    if task in shared.cmd_opts.use_cpu:
+        return cpu
+
+    return get_optimal_device()
+
+
 def torch_gc():
     if torch.cuda.is_available():
         with torch.cuda.device(get_cuda_device_string()):
@@ -53,12 +62,12 @@ def torch_gc():

 def enable_tf32():
     if torch.cuda.is_available():
-        for devid in range(0,torch.cuda.device_count()):
-            if torch.cuda.get_device_capability(devid) == (7, 5):
-                shd = True
-        if shd:
+
+        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
             torch.backends.cudnn.benchmark = True
-            torch.backends.cudnn.enabled = True
+
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True

@@ -67,7 +76,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")

 cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None

 dtype = torch.float16
 dtype_vae = torch.float16