-rw-r--r--   modules/sub_quadratic_attention.py   34
1 file changed, 19 insertions, 15 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index b11dc1c7..95924d24 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -5,6 +5,7 @@
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# brkirch (modified to use torch.narrow instead of the dynamic_slice implementation)
# implementation of:
# "Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
@@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, Protocol, List
-def dynamic_slice(
- x: Tensor,
- starts: List[int],
- sizes: List[int],
+def narrow_trunc(
+ input: Tensor,
+ dim: int,
+ start: int,
+ length: int
) -> Tensor:
- slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
- return x[slicing]
+ return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
class AttnChunk(NamedTuple):
exp_values: Tensor
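
Not part of the patch, just a sketch of the behaviour the new helper keeps from dynamic_slice: torch.narrow requires start + length to stay within the dimension, so narrow_trunc clamps the length at the tensor edge, which matches the silent truncation the old Python-slicing version got for free. Shapes below are made up for illustration.

    import torch
    from torch import Tensor

    def narrow_trunc(input: Tensor, dim: int, start: int, length: int) -> Tensor:
        # clamp the slice length so the final, partial chunk does not overrun the tensor
        # (a plain torch.narrow would error when start + length exceeds the dimension size)
        return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)

    x = torch.arange(10).reshape(1, 10, 1)
    assert narrow_trunc(x, 1, 0, 4).shape == (1, 4, 1)        # full chunk
    assert narrow_trunc(x, 1, 8, 4).shape == (1, 2, 1)        # only 2 tokens left, length clamped
    assert torch.equal(narrow_trunc(x, 1, 8, 4), x[:, 8:12])  # same truncation the old slicing gave
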
@@ -76,15 +77,17 @@ def _query_chunk_attention(
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
- key_chunk = dynamic_slice(
+ key_chunk = narrow_trunc(
key,
- (0, chunk_idx, 0),
- (batch_x_heads, kv_chunk_size, k_channels_per_head)
+ 1,
+ chunk_idx,
+ kv_chunk_size
)
- value_chunk = dynamic_slice(
+ value_chunk = narrow_trunc(
value,
- (0, chunk_idx, 0),
- (batch_x_heads, kv_chunk_size, v_channels_per_head)
+ 1,
+ chunk_idx,
+ kv_chunk_size
)
return summarize_chunk(query, key_chunk, value_chunk)
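
A sketch of why that clamping matters in chunk_scanner (tensor sizes are invented for illustration, and the stepping of chunk_idx by kv_chunk_size is assumed from the surrounding, unchanged code): key and value are narrowed along dim 1, and when the token count is not a multiple of the chunk size the last chunk simply comes back shorter.

    import torch

    def narrow_trunc(input, dim, start, length):
        # narrow_trunc as introduced in the patch above
        return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)

    batch_x_heads, kv_tokens, channels = 2, 10, 8   # illustrative sizes
    key = torch.randn(batch_x_heads, kv_tokens, channels)
    kv_chunk_size = 4

    for chunk_idx in range(0, kv_tokens, kv_chunk_size):
        key_chunk = narrow_trunc(key, 1, chunk_idx, kv_chunk_size)
        print(chunk_idx, tuple(key_chunk.shape))
    # prints: 0 (2, 4, 8) / 4 (2, 4, 8) / 8 (2, 2, 8) -- final chunk truncated to the 2 remaining tokens
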
@@ -161,10 +164,11 @@ def efficient_dot_product_attention(
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
def get_query_chunk(chunk_idx: int) -> Tensor:
- return dynamic_slice(
+ return narrow_trunc(
query,
- (0, chunk_idx, 0),
- (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
+ 1,
+ chunk_idx,
+ min(query_chunk_size, q_tokens)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
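
The query side follows the same pattern; the only extra wrinkle visible in this hunk is that get_query_chunk caps the requested length at min(query_chunk_size, q_tokens), so a sequence shorter than the chunk size is returned as a single chunk. A small sketch with invented sizes:

    import torch

    def narrow_trunc(input, dim, start, length):
        # narrow_trunc as introduced in the patch above
        return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)

    batch_x_heads, q_tokens, channels = 2, 6, 8   # illustrative sizes
    query = torch.randn(batch_x_heads, q_tokens, channels)
    query_chunk_size = 1024                       # hypothetical chunk size

    chunk = narrow_trunc(query, 1, 0, min(query_chunk_size, q_tokens))
    print(tuple(chunk.shape))                     # (2, 6, 8): the whole query in one chunk
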