diff options
Diffstat (limited to 'modules/xpu_specific.py')
-rw-r--r-- | modules/xpu_specific.py | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66..4e11125b 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length |