aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/xpu_specific.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index f7687a66..4e11125b 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
# cast to same dtype first
key = key.to(query.dtype)
value = value.to(query.dtype)
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(query.dtype)
N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length