author     AUTOMATIC <16777216c@gmail.com>    2023-05-11 07:45:05 +0300
committer  AUTOMATIC <16777216c@gmail.com>    2023-05-11 07:45:05 +0300
commit     e334758ec281eaf7723c806713721d12bb568e24 (patch)
tree       1f34358bb006da9aa4baee64aaecec2bdfd333b3 /modules
parent     c9e5b921061d842ef64efcf50431253b3002e1ed (diff)
fix #10266
Diffstat (limited to 'modules')
-rw-r--r--  modules/sub_quadratic_attention.py  18
1 file changed, 5 insertions(+), 13 deletions(-)
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index f80c1600..cc38debd 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -201,23 +201,15 @@ def efficient_dot_product_attention(
             key=key,
             value=value,
         )
-
-    # slices of res tensor are mutable, modifications made
-    # to the slices will affect the original tensor.
-    # if output of compute_query_chunk_attn function has same number of
-    # dimensions as input query tensor, we initialize tensor like this:
-    num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
-    query_shape = get_query_chunk(0).shape
-    res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
-    res_dtype = get_query_chunk(0).dtype
-    res = torch.zeros(res_shape, dtype=res_dtype)
-
-    for i in range(num_query_chunks):
+
+    res = torch.zeros_like(query)
+    for i in range(math.ceil(q_tokens / query_chunk_size)):
         attn_scores = compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
             key=key,
             value=value,
         )
-        res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
+
+        res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
     return res
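
For illustration, here is a minimal, self-contained sketch (not the repository's code) of the chunked-write pattern the patch restores: the output buffer is allocated with torch.zeros_like(query), and each write slice ends at the chunk's actual length, so a final partial chunk neither overruns the buffer nor leaves the result mis-sized. The helper name chunked_identity is hypothetical, and a plain slice copy stands in for compute_query_chunk_attn.

import math

import torch


def chunked_identity(query: torch.Tensor, query_chunk_size: int) -> torch.Tensor:
    # Stand-in for the attention loop in the diff: copies `query` chunk by
    # chunk instead of calling compute_query_chunk_attn.
    q_tokens = query.shape[1]
    res = torch.zeros_like(query)
    for i in range(math.ceil(q_tokens / query_chunk_size)):
        # The last chunk is shorter than query_chunk_size whenever
        # q_tokens % query_chunk_size != 0; Python slicing clips it.
        chunk = query[:, i * query_chunk_size:(i + 1) * query_chunk_size, :]
        # End the write slice at the chunk's actual length, as in the patch.
        res[:, i * query_chunk_size:i * query_chunk_size + chunk.shape[1], :] = chunk
    return res


# 10 tokens with a chunk size of 4 leaves a final chunk of only 2 tokens.
q = torch.randn(1, 10, 8)
assert torch.equal(chunked_identity(q, query_chunk_size=4), q)

From the deleted lines, the old code sized res along the token axis as get_query_chunk(0).shape[1] * num_query_chunks, which exceeds q_tokens whenever q_tokens is not a multiple of query_chunk_size and leaves zero padding in the output; zeros_like(query) keeps the result exactly query-shaped.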