author    C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>  2022-10-07 05:21:49 +0300
committer GitHub <noreply@github.com>  2022-10-07 05:21:49 +0300
commit    f174fb29228a04955fb951b32b0bab79e33ec2b8 (patch)
tree      ebc92c1e26e76c2ce095905bb1d26b5be825fc6e
parent    2995107fa24cfd72b0a991e18271dcde148c2807 (diff)
add xformers attention
-rw-r--r--  modules/sd_hijack_optimizations.py  39
1 file changed, 38 insertions(+), 1 deletion(-)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index ea4cfdfc..da1b76e1 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,9 @@
import math
import torch
from torch import einsum
-
+import xformers.ops
+import functorch
+xformers._is_functorch_available=True
from ldm.util import default
from einops import rearrange
@@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    return self.to_out(r2)
+def _maybe_init(self, x):
+    """
+    Initialize the attention operator, if required. We expect the head dimension to be exposed here,
+    meaning that x has shape: B, Head, Length
+    """
+    if self.attention_op is not None:
+        return
+    _, M, K = x.shape
+    try:
+        self.attention_op = xformers.ops.AttentionOpDispatch(
+            dtype=x.dtype,
+            device=x.device,
+            k=K,
+            attn_bias_type=type(None),
+            has_dropout=False,
+            kv_len=M,
+            q_len=M,
+        ).op
+    except NotImplementedError as err:
+        raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
+
+def xformers_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+    k_in = self.to_k(context)
+    v_in = self.to_v(context)
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+    del q_in, k_in, v_in
+    self._maybe_init(q)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
+
def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
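
Note: the diff above only adds the new forward; it does not show where it gets hooked in. As a minimal sketch (not part of this commit), assuming the webui's usual approach of monkey-patching ldm's CrossAttention from modules/sd_hijack.py, the wiring could look roughly like the following. The helper name apply_xformers_attention is hypothetical; the attention_op default is needed because _maybe_init and xformers_attention_forward both expect it on self.

# Hypothetical wiring sketch, not part of this commit: attach the attention_op
# slot and the _maybe_init helper before swapping in the new forward method.
import ldm.modules.attention
from modules import sd_hijack_optimizations

def apply_xformers_attention():
    cross_attention = ldm.modules.attention.CrossAttention
    cross_attention.attention_op = None  # filled lazily by _maybe_init on the first forward pass
    cross_attention._maybe_init = sd_hijack_optimizations._maybe_init
    cross_attention.forward = sd_hijack_optimizations.xformers_attention_forward

In the repository itself, the choice between split_cross_attention_forward and this xformers path is made in modules/sd_hijack.py when the hijack is applied.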