Diffstat (limited to 'modules')
-rw-r--r--  modules/xpu_specific.py  |  20 ++++++++++++++++----
1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 4e11125b..1137891a 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
return torch.reshape(result, (*N, L, Ev))
+def is_xpu_device(device: str | torch.device | None = None) -> bool:
+ if device is None:
+ return False
+ if isinstance(device, str):
+ return device.startswith("xpu")
+ return device.type == "xpu"
+
+
if has_xpu:
- # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
- CondFunc('torch.Generator',
- lambda orig_func, device=None: torch.xpu.Generator(device),
- lambda orig_func, device=None: device is not None and device.type == "xpu")
+ try:
+ # torch.Generator supports "xpu" device since 2.1
+ torch.Generator("xpu")
+    except Exception:
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1)
+ CondFunc('torch.Generator',
+ lambda orig_func, device=None: torch.xpu.Generator(device),
+ lambda orig_func, device=None: is_xpu_device(device))
# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',
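
For context, here is a minimal standalone sketch of the feature-detection pattern this commit introduces: probe once whether torch.Generator accepts an "xpu" device, and only patch on older IPEX builds where it does not. This is an assumption-laden illustration, not the repo's actual mechanism — it substitutes plain attribute reassignment and a hypothetical patch_generator_if_needed helper for the real CondFunc hook defined in modules/sd_hijack_utils.py.

import torch


def is_xpu_device(device: str | torch.device | None = None) -> bool:
    # Mirror of the helper added in the diff: accepts None, a string
    # such as "xpu" or "xpu:0", or a torch.device instance.
    if device is None:
        return False
    if isinstance(device, str):
        return device.startswith("xpu")
    return device.type == "xpu"


def patch_generator_if_needed() -> None:
    # Hypothetical, simplified stand-in for the CondFunc-based workaround.
    try:
        # Probe: on IPEX >= 2.1 torch.Generator takes "xpu" natively,
        # so no patching is required.
        torch.Generator("xpu")
    except Exception:
        # Older builds raise here; route XPU requests to
        # torch.xpu.Generator and leave other devices untouched.
        original_generator = torch.Generator

        def patched_generator(device=None):
            if is_xpu_device(device):
                return torch.xpu.Generator(device)
            if device is not None:
                return original_generator(device)
            return original_generator()

        # Note: reassigning the class is illustrative only; the real code
        # hooks the call site conditionally via CondFunc instead.
        torch.Generator = patched_generator

With a hook like this in place, callers that write torch.Generator(device="xpu") behave the same on pre-2.1 and post-2.1 IPEX builds, while CPU and CUDA generators keep using the original constructor.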