author     AUTOMATIC1111 <16777216c@gmail.com>   2023-05-11 07:21:18 +0300
committer  GitHub <noreply@github.com>           2023-05-11 07:21:18 +0300
commit     c9e5b921061d842ef64efcf50431253b3002e1ed (patch)
tree       92723cd92da2d1557571778b1c44c81182eb8ea4 /modules
parent     8aa87c564a79965013715d56a5f90d2a34d5d6ee (diff)
parent     c8732dfa6f763332962d97ff040af156e24a9e62 (diff)
Merge pull request #10266 from nero-dv/dev
Update sub_quadratic_attention.py
Diffstat (limited to 'modules')
-rw-r--r--  modules/sub_quadratic_attention.py | 21 +++++++++++++++------
1 file changed, 15 insertions(+), 6 deletions(-)
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 05595323..f80c1600 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -202,13 +202,22 @@ def efficient_dot_product_attention(
             value=value,
         )

-    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
-    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
-    res = torch.cat([
-        compute_query_chunk_attn(
+    # Slices of the res tensor are mutable: writing into a slice
+    # modifies the original tensor in place. The output of
+    # compute_query_chunk_attn has the same trailing dimensions as the
+    # query chunks, so we can pre-allocate res and fill it chunk by chunk.
+    num_query_chunks = math.ceil(q_tokens / query_chunk_size)
+    first_chunk = get_query_chunk(0)
+    res_shape = (first_chunk.shape[0], q_tokens, *first_chunk.shape[2:])
+    res = torch.zeros(res_shape, dtype=first_chunk.dtype,
+                      device=first_chunk.device)
+
+    for i in range(num_query_chunks):
+        attn_scores = compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
             key=key,
             value=value,
-        ) for i in range(math.ceil(q_tokens / query_chunk_size))
-    ], dim=1)
+        )
+        res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
+
     return res
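
For context, here is a minimal standalone sketch of the pattern this patch adopts. It is not the webui code: chunked_attention_cat and chunked_attention_prealloc are hypothetical names, and plain unscaled softmax attention stands in for compute_query_chunk_attn. The point is that writing chunk outputs into slices of a pre-allocated tensor yields the same result as torch.cat()ing them, while avoiding cat()'s extra output allocation on top of the still-live per-chunk results:

import math
import torch

def chunked_attention_cat(query, key, value, chunk_size):
    # Old pattern: collect per-chunk outputs, then torch.cat() them.
    # cat() allocates a fresh result tensor while the chunk list is
    # still alive, raising peak memory.
    chunks = [
        torch.softmax(q_chunk @ key.transpose(-2, -1), dim=-1) @ value
        for q_chunk in query.split(chunk_size, dim=1)
    ]
    return torch.cat(chunks, dim=1)

def chunked_attention_prealloc(query, key, value, chunk_size):
    # New pattern: pre-allocate the result once and write each chunk
    # into a slice; slices are views, so the writes mutate res in place.
    batch, q_tokens, _ = query.shape
    res = torch.zeros(batch, q_tokens, value.shape[-1],
                      dtype=query.dtype, device=query.device)
    for i in range(math.ceil(q_tokens / chunk_size)):
        q_chunk = query[:, i * chunk_size:(i + 1) * chunk_size, :]
        attn = torch.softmax(q_chunk @ key.transpose(-2, -1), dim=-1) @ value
        res[:, i * chunk_size:(i + 1) * chunk_size, :] = attn
    return res

torch.manual_seed(0)
q = torch.randn(2, 10, 8)   # 10 query tokens: not a multiple of chunk_size
k = torch.randn(2, 16, 8)
v = torch.randn(2, 16, 8)
print(torch.allclose(chunked_attention_cat(q, k, v, chunk_size=4),
                     chunked_attention_prealloc(q, k, v, chunk_size=4)))  # True

The odd token count exercises the partial last chunk: Python slicing clamps (i + 1) * chunk_size to q_tokens, so the final, shorter attention block still lands in the right rows.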