aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/sub_quadratic_attention.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 93381bae..55052815 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,14 +15,9 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-
-try:
- from typing import Protocol
-except:
- from typing_extensions import Protocol
-
from typing import Optional, NamedTuple, List
+
def narrow_trunc(
input: Tensor,
dim: int,
@@ -31,12 +26,14 @@ def narrow_trunc(
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
-class SummarizeChunk(Protocol):
+
+class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
@@ -44,7 +41,8 @@ class SummarizeChunk(Protocol):
value: Tensor,
) -> AttnChunk: ...
-class ComputeQueryChunkAttn(Protocol):
+
+class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
@@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor,
) -> Tensor: ...
+
def _summarize_chunk(
query: Tensor,
key: Tensor,
@@ -72,6 +71,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
def _query_chunk_attention(
query: Tensor,
key: Tensor,
@@ -112,6 +112,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
@@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
+
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
+
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,