aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py76
1 files changed, 73 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 1d4eb563..ff279ac5 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,10 +4,18 @@ from functools import lru_cache
import torch
from modules import errors, shared
+from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
+if shared.cmd_opts.use_ipex:
+ from modules import xpu_specific
+
+
+def has_xpu() -> bool:
+ return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
+
def has_mps() -> bool:
if sys.platform != "darwin":
@@ -16,6 +24,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
+def cuda_no_autocast(device_id=None) -> bool:
+ if device_id is None:
+ device_id = get_cuda_device_id()
+ return (
+ torch.cuda.get_device_capability(device_id) == (7, 5)
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+ )
+
+
+def get_cuda_device_id():
+ return (
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ else 0
+ ) or torch.cuda.current_device()
+
+
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -30,6 +55,9 @@ def get_optimal_device_name():
if has_mps():
return "mps"
+ if has_xpu():
+ return xpu_specific.get_xpu_device_string()
+
return "cpu"
@@ -38,7 +66,7 @@ def get_optimal_device():
def get_device_for(task):
- if task in shared.cmd_opts.use_cpu:
+ if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
@@ -54,14 +82,16 @@ def torch_gc():
if has_mps():
mac_specific.torch_mps_gc()
+ if has_xpu():
+ xpu_specific.torch_xpu_gc()
+
def enable_tf32():
if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
- if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+ if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -71,6 +101,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
+fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -91,12 +122,51 @@ def cond_cast_float(input):
nv_rng = None
+patch_module_list = [
+ torch.nn.Linear,
+ torch.nn.Conv2d,
+ torch.nn.MultiheadAttention,
+ torch.nn.GroupNorm,
+ torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(self, *args, **kwargs):
+ org_dtype = torch_utils.get_param(self).dtype
+ self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+ result = self.org_forward(*args, **kwargs)
+ self.to(org_dtype)
+ return result
+
+
+@contextlib.contextmanager
+def manual_cast():
+ for module_type in patch_module_list:
+ org_forward = module_type.forward
+ module_type.forward = manual_cast_forward
+ module_type.org_forward = org_forward
+ try:
+ yield None
+ finally:
+ for module_type in patch_module_list:
+ module_type.forward = module_type.org_forward
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
+ if fp8 and device==cpu:
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+ if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+ return manual_cast()
+
+ if has_mps() and shared.cmd_opts.precision != "full":
+ return manual_cast()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()