aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/xpu_specific.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 0ebdd596..f7687a66 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -33,7 +33,7 @@ has_xpu = check_for_xpu()
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
# which is the best trade-off between VRAM usage and performance.
-ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ARC_SINGLE_ALLOCATION_LIMIT = {}
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
def torch_xpu_scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
@@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention(
Ev = value.size(-1) # Embedding dimension of the value
total_batch_size = torch.numel(torch.empty(N))
- batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size()))
+ device_id = query.device.index
+ if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+ ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
if total_batch_size <= batch_size_limit:
return orig_sdp_attn_func(