aboutsummaryrefslogtreecommitdiff
path: root/modules/xpu_specific.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/xpu_specific.py')
-rw-r--r--modules/xpu_specific.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index d933c790..f7687a66 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -27,6 +27,71 @@ def torch_xpu_gc():
has_xpu = check_for_xpu()
+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# which is the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = {}
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+):
+ # cast to same dtype first
+ key = key.to(query.dtype)
+ value = value.to(query.dtype)
+
+ N = query.shape[:-2] # Batch size
+ L = query.size(-2) # Target sequence length
+ E = query.size(-1) # Embedding dimension of the query and key
+ S = key.size(-2) # Source sequence length
+ Ev = value.size(-1) # Embedding dimension of the value
+
+ total_batch_size = torch.numel(torch.empty(N))
+ device_id = query.device.index
+ if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+ ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+
+ if total_batch_size <= batch_size_limit:
+ return orig_sdp_attn_func(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+
+ query = torch.reshape(query, (-1, L, E))
+ key = torch.reshape(key, (-1, S, E))
+ value = torch.reshape(value, (-1, S, Ev))
+ if attn_mask is not None:
+ attn_mask = attn_mask.view(-1, L, S)
+ chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+ outputs = []
+ for i in range(chunk_count):
+ attn_mask_chunk = (
+ None
+ if attn_mask is None
+ else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+ )
+ chunk_output = orig_sdp_attn_func(
+ query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ attn_mask_chunk,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+ outputs.append(chunk_output)
+ result = torch.cat(outputs, dim=0)
+ return torch.reshape(result, (*N, L, Ev))
+
+
if has_xpu:
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
CondFunc('torch.Generator',
@@ -48,3 +113,12 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.bmm',
+ lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+ lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
+ CondFunc('torch.cat',
+ lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
+ lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
+ CondFunc('torch.nn.functional.scaled_dot_product_attention',
+ lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+ lambda orig_func, query, *args, **kwargs: query.is_xpu)