author     brkirch <brkirch@users.noreply.github.com>   2022-12-27 08:50:55 -0500
committer  brkirch <brkirch@users.noreply.github.com>   2023-01-06 00:14:13 -0500
commit     d782a95967c9eea753df3333cd1954b6ec73eba0 (patch)
tree       00e368f428916688518c3171bca8c01b92f4e549 /modules
parent     4af3ca5393151d61363c30eef4965e694eeac15e (diff)
Add Birch-san's sub-quadratic attention implementation
Diffstat (limited to 'modules')
-rw-r--r--  modules/sd_hijack.py                  15
-rw-r--r--  modules/sd_hijack_optimizations.py   124
-rw-r--r--  modules/shared.py                      4
-rw-r--r--  modules/sub_quadratic_attention.py   201
4 files changed, 310 insertions, 34 deletions
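
The patch adds an opt-in attention path whose peak memory grows sub-quadratically with token count. Standard attention materializes a full [query_tokens x key_tokens] score matrix per head; the implementation added here chunks the queries and the keys/values and combines the partial softmax results, following "Self-attention Does Not Need O(n²) Memory" (https://arxiv.org/abs/2112.05682v2). As a rough illustration of why that matters, a back-of-the-envelope estimate of the score matrix alone (the resolution, head count, and dtype below are assumptions for the example, not values from the patch):

    # Illustrative only: size of the full q @ k^T score matrix the unchunked path materializes.
    # Assumed: a 1024x1024 image -> 128x128 latent -> 16384 tokens, 8 heads, batch 1, fp16.
    batch, heads, tokens, bytes_per_element = 1, 8, 128 * 128, 2
    score_matrix_bytes = batch * heads * tokens * tokens * bytes_per_element
    print(score_matrix_bytes / 2**30)  # 4.0 -> about 4 GiB for the attention scores alone

sub_quad_attention below computes exactly this quantity (qk_matmul_size_bytes) and compares it against available memory to decide whether chunking is needed.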
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 690a9ec2..019a6f3f 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
@@ -40,17 +38,16 @@ def apply_optimizations():
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+    elif cmd_opts.opt_sub_quad_attention:
+        print("Applying sub-quadratic cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
-        if not invokeAI_mps_available and shared.device.type == 'mps':
-            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
-            print("Applying v1 cross attention optimization.")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-        else:
-            print("Applying cross attention optimization (InvokeAI).")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+        print("Applying cross attention optimization (InvokeAI).")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
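
All of these optimizations are applied by monkey-patching the ldm classes: apply_optimizations assigns a replacement function to CrossAttention.forward and AttnBlock.forward, and the new --opt-sub-quad-attention branch sits after xformers but ahead of the v1/InvokeAI/Doggettx branches, so it takes precedence over those when several flags are given. A minimal sketch of the pattern with a toy class (not the real ldm modules):

    # Toy illustration of the monkey-patching pattern used by apply_optimizations().
    # The replacement takes `self` explicitly because it is assigned on the class itself.
    class ToyAttention:
        def forward(self, x):
            return x  # original, unoptimized implementation

    def optimized_forward(self, x):
        return x * 2  # stand-in for an optimized attention kernel

    ToyAttention.forward = optimized_forward  # every instance, existing or future, now uses it
    assert ToyAttention().forward(3) == 6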
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 02c87f40..f5c153e8 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,7 @@
import math
import sys
import traceback
-import importlib
+import psutil
import torch
from torch import einsum
@@ -12,6 +12,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
+from .sub_quadratic_attention import efficient_dot_product_attention
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
@@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
        print(traceback.format_exc(), file=sys.stderr)
+def get_available_vram():
+    if shared.device.type == 'cuda':
+        stats = torch.cuda.memory_stats(shared.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        return mem_free_total
+    else:
+        return psutil.virtual_memory().available
+
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads
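
get_available_vram consolidates the free-memory estimate that split_cross_attention_forward and cross_attention_attnblock_forward previously computed inline: on CUDA it is the driver-reported free memory plus memory PyTorch has reserved but is not actively using, and on every other device it falls back to psutil's available system RAM (which is what the CPU and MPS backends allocate from). A hedged sketch of how the rest of the patch budgets that estimate; the fractions come from sub_quad_attention below, the byte count is a made-up stand-in:

    # Sketch only: sub_quad_attention() uses a fraction of the estimate as its default
    # chunking threshold -- 90% on MPS (which shares system RAM), 70% elsewhere.
    mem_free_total = 8 * 1024**3          # pretend get_available_vram() reported 8 GiB
    on_mps = False                        # assumption: a CUDA machine
    budget = 0.9 if on_mps else 0.7
    chunk_threshold_bytes = int(mem_free_total * budget)   # ~5.6 GiB to spend before chunking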
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()
    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    return self.to_out(r2)
-def check_for_psutil():
-    try:
-        spec = importlib.util.find_spec('psutil')
-        return spec is not None
-    except ModuleNotFoundError:
-        return False
-
-invokeAI_mps_available = check_for_psutil()
-
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-if invokeAI_mps_available:
-    import psutil
-    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
@@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+    out_proj, dropout = self.to_out
+    x = out_proj(x)
+    x = dropout(x)
+
+    return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
+    bytes_per_token = torch.finfo(q.dtype).bits//8
+    batch_x_heads, q_tokens, _ = q.shape
+    _, k_tokens, _ = k.shape
+    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+    available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+
+    if chunk_threshold_bytes is None:
+        chunk_threshold_bytes = available_vram
+    elif chunk_threshold_bytes == 0:
+        chunk_threshold_bytes = None
+
+    if kv_chunk_size_min is None:
+        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+    elif kv_chunk_size_min == 0:
+        kv_chunk_size_min = None
+
+    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+        # the big matmul fits into our memory limit; do everything in 1 chunk,
+        # i.e. send it down the unchunked fast-path
+        query_chunk_size = q_tokens
+        kv_chunk_size = k_tokens
+
+    return efficient_dot_product_attention(
+        q,
+        k,
+        v,
+        query_chunk_size=q_chunk_size,
+        kv_chunk_size=kv_chunk_size,
+        kv_chunk_size_min = kv_chunk_size_min,
+        use_checkpoint=use_checkpoint,
+    )
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
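
sub_quad_attention_forward above folds the head dimension into the batch dimension before calling the chunked kernel and unfolds it afterwards, because efficient_dot_product_attention works on [batch * heads, tokens, channels_per_head] tensors. A small shape walk-through with assumed sizes (not values from the patch):

    import torch

    b, t, h, d = 2, 4096, 8, 40                   # batch, tokens, heads, channels per head
    q = torch.randn(b, t, h * d)                  # what self.to_q(x) would produce

    # [b, t, h*d] -> [b, t, h, d] -> [b, h, t, d] -> [b*h, t, d]
    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    assert q.shape == (b * h, t, d)

    # ...the chunked attention then runs independently over each (batch, head) slice...

    # [b*h, t, d] -> [b, h, t, d] -> [b, t, h, d] -> [b, t, h*d]
    q = q.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)
    assert q.shape == (b, t, h * d)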
@@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x):
    h_ = torch.zeros_like(k, device=q.device)
-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
@@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x):
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
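
sub_quad_attnblock_forward treats every spatial position of the feature map as a token: the first rearrange flattens (h, w) into one token axis and moves channels last, and the inverse rearrange only needs h to restore the 2D layout. It also passes use_checkpoint=self.training, so activation checkpointing stays off during normal inference. A minimal round-trip illustration with assumed sizes:

    import torch
    from einops import rearrange

    x = torch.randn(1, 512, 64, 64)                 # assumed: batch 1, 512 channels, 64x64 feature map
    t = rearrange(x, 'b c h w -> b (h w) c')        # (1, 4096, 512): one token per spatial position
    y = rearrange(t, 'b (h w) c -> b c h w', h=64)  # back to (1, 512, 64, 64)
    assert torch.equal(x, y)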
diff --git a/modules/shared.py b/modules/shared.py
index d4ddeea0..487a7792 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
new file mode 100644
index 00000000..b11dc1c7
--- /dev/null
+++ b/modules/sub_quadratic_attention.py
@@ -0,0 +1,201 @@
+# original source:
+#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+#   unspecified
+# credit:
+#   Amin Rezaei (original author)
+#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# implementation of:
+#   "Self-attention Does Not Need O(n²) Memory":
+#   https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+from typing import Optional, NamedTuple, Protocol, List
+
+def dynamic_slice(
+    x: Tensor,
+    starts: List[int],
+    sizes: List[int],
+) -> Tensor:
+    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    return x[slicing]
+
+class AttnChunk(NamedTuple):
+    exp_values: Tensor
+    exp_weights_sum: Tensor
+    max_score: Tensor
+
+class SummarizeChunk(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+    ) -> AttnChunk: ...
+
+class ComputeQueryChunkAttn(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+    ) -> Tensor: ...
+
+# Computes attention over a single key/value chunk, returning the un-normalized weighted values,
+# the sum of the max-shifted exponentiated scores, and the per-query max score within the chunk,
+# so that chunks can later be rescaled and combined (the streaming / log-sum-exp softmax trick).
+def _summarize_chunk(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    scale: float,
+) -> AttnChunk:
+    attn_weights = torch.baddbmm(
+        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        query,
+        key.transpose(1,2),
+        alpha=scale,
+        beta=0,
+    )
+    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+    max_score = max_score.detach()
+    exp_weights = torch.exp(attn_weights - max_score)
+    exp_values = torch.bmm(exp_weights, value)
+    max_score = max_score.squeeze(-1)
+    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
+def _query_chunk_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    summarize_chunk: SummarizeChunk,
+    kv_chunk_size: int,
+) -> Tensor:
+    batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    _, _, v_channels_per_head = value.shape
+
+    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+        key_chunk = dynamic_slice(
+            key,
+            (0, chunk_idx, 0),
+            (batch_x_heads, kv_chunk_size, k_channels_per_head)
+        )
+        value_chunk = dynamic_slice(
+            value,
+            (0, chunk_idx, 0),
+            (batch_x_heads, kv_chunk_size, v_channels_per_head)
+        )
+        return summarize_chunk(query, key_chunk, value_chunk)
+
+    chunks: List[AttnChunk] = [
+        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+    ]
+    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+    chunk_values, chunk_weights, chunk_max = acc_chunk
+
+    # rescale each chunk by exp(chunk_max - global_max) so every chunk is expressed relative to the same max before summing
+    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+    max_diffs = torch.exp(chunk_max - global_max)
+    chunk_values *= torch.unsqueeze(max_diffs, -1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(dim=0)
+    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+    return all_values / all_weights
+
+# TODO: refactor CrossAttention#get_attention_scores to share code with this
+def _get_attention_scores_no_kv_chunking(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    scale: float,
+) -> Tensor:
+    attn_scores = torch.baddbmm(
+        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        query,
+        key.transpose(1,2),
+        alpha=scale,
+        beta=0,
+    )
+    attn_probs = attn_scores.softmax(dim=-1)
+    del attn_scores
+    hidden_states_slice = torch.bmm(attn_probs, value)
+    return hidden_states_slice
+
+class ScannedChunk(NamedTuple):
+    chunk_idx: int
+    attn_chunk: AttnChunk
+
+def efficient_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    query_chunk_size=1024,
+    kv_chunk_size: Optional[int] = None,
+    kv_chunk_size_min: Optional[int] = None,
+    use_checkpoint=True,
+):
+ """Computes efficient dot-product attention given query, key, and value.
+ This is efficient version of attention presented in
+ https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
+ Args:
+ query: queries for calculating attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ key: keys for calculating attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ value: values to be used in attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ query_chunk_size: int: query chunks size
+ kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
+ kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+ Returns:
+ Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+ """
+    batch_x_heads, q_tokens, q_channels_per_head = query.shape
+    _, k_tokens, _ = key.shape
+    scale = q_channels_per_head ** -0.5
+
+    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+    if kv_chunk_size_min is not None:
+        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
+
+    def get_query_chunk(chunk_idx: int) -> Tensor:
+        return dynamic_slice(
+            query,
+            (0, chunk_idx, 0),
+            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
+        )
+
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
+        _get_attention_scores_no_kv_chunking,
+        scale=scale
+    ) if k_tokens <= kv_chunk_size else (
+        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+        partial(
+            _query_chunk_attention,
+            kv_chunk_size=kv_chunk_size,
+            summarize_chunk=summarize_chunk,
+        )
+    )
+
+    if q_tokens <= query_chunk_size:
+        # fast-path for when there's just 1 query chunk
+        return compute_query_chunk_attn(
+            query=query,
+            key=key,
+            value=value,
+        )
+
+    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
+    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
+    res = torch.cat([
+        compute_query_chunk_attn(
+            query=get_query_chunk(i * query_chunk_size),
+            key=key,
+            value=value,
+        ) for i in range(math.ceil(q_tokens / query_chunk_size))
+    ], dim=1)
+    return res
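
Because the chunked computation only reorders the softmax arithmetic (each key/value chunk is accumulated with a running maximum and a running sum of exponentiated scores), its output should match ordinary attention up to floating-point error. A quick self-check against a naive reference, runnable from the repository root; the shapes, chunk sizes, and tolerance are assumptions chosen so that both the query chunking and the key/value chunking paths are exercised:

    import torch
    from modules.sub_quadratic_attention import efficient_dot_product_attention

    q = torch.randn(2, 64, 40)   # [batch*heads, query tokens, channels per head]
    k = torch.randn(2, 80, 40)
    v = torch.randn(2, 80, 40)

    # Naive reference: softmax(q @ k^T * scale) @ v, materializing the full score matrix.
    scale = q.shape[-1] ** -0.5
    ref = torch.softmax(q @ k.transpose(1, 2) * scale, dim=-1) @ v

    # Chunked version with deliberately small chunks; checkpointing off, as at inference.
    out = efficient_dot_product_attention(q, k, v, query_chunk_size=16, kv_chunk_size=32, use_checkpoint=False)

    assert torch.allclose(out, ref, atol=1e-5)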