From c8732dfa6f763332962d97ff040af156e24a9e62 Mon Sep 17 00:00:00 2001 From: Louis Del Valle <92354925+nero-dv@users.noreply.github.com> Date: Wed, 10 May 2023 22:05:18 -0500 Subject: Update sub_quadratic_attention.py 1. Determine the number of query chunks. 2. Calculate the final shape of the res tensor. 3. Initialize the tensor with the calculated shape and dtype, (same dtype as the input tensors, usually) Can initialize the tensor as a zero-filled tensor with the correct shape and dtype, then compute the attention scores for each query chunk and fill the corresponding slice of tensor. --- modules/sub_quadratic_attention.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 05595323..f80c1600 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -202,13 +202,22 @@ def efficient_dot_product_attention( value=value, ) - # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, - # and pass slices to be mutated, instead of torch.cat()ing the returned slices - res = torch.cat([ - compute_query_chunk_attn( + # slices of res tensor are mutable, modifications made + # to the slices will affect the original tensor. + # if output of compute_query_chunk_attn function has same number of + # dimensions as input query tensor, we initialize tensor like this: + num_query_chunks = int(np.ceil(q_tokens / query_chunk_size)) + query_shape = get_query_chunk(0).shape + res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:]) + res_dtype = get_query_chunk(0).dtype + res = torch.zeros(res_shape, dtype=res_dtype) + + for i in range(num_query_chunks): + attn_scores = compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key=key, value=value, - ) for i in range(math.ceil(q_tokens / query_chunk_size)) - ], dim=1) + ) + res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores + return res -- cgit v1.2.1 From e334758ec281eaf7723c806713721d12bb568e24 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 11 May 2023 07:45:05 +0300 Subject: repair #10266 --- modules/sub_quadratic_attention.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index f80c1600..cc38debd 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -201,23 +201,15 @@ def efficient_dot_product_attention( key=key, value=value, ) - - # slices of res tensor are mutable, modifications made - # to the slices will affect the original tensor. - # if output of compute_query_chunk_attn function has same number of - # dimensions as input query tensor, we initialize tensor like this: - num_query_chunks = int(np.ceil(q_tokens / query_chunk_size)) - query_shape = get_query_chunk(0).shape - res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:]) - res_dtype = get_query_chunk(0).dtype - res = torch.zeros(res_shape, dtype=res_dtype) - - for i in range(num_query_chunks): + + res = torch.zeros_like(query) + for i in range(math.ceil(q_tokens / query_chunk_size)): attn_scores = compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key=key, value=value, ) - res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores + + res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores return res -- cgit v1.2.1 From 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 18:28:15 +0300 Subject: Autofix Ruff W (not W605) (mostly whitespace) --- modules/sub_quadratic_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index cc38debd..497568eb 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -179,7 +179,7 @@ def efficient_dot_product_attention( chunk_idx, min(query_chunk_size, q_tokens) ) - + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( -- cgit v1.2.1