aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorC43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>2022-10-08 04:09:18 +0300
committerGitHub <noreply@github.com>2022-10-08 04:09:18 +0300
commitc9cc65b201679ea43c763b0d85e749d40bbc5433 (patch)
treec8d009be7f85f00d6751b64b7d7c738770eca549
parent5e3ff846c56dc8e1d5c76ea04a8f2f74d7da07fc (diff)
switch to the proper way of calling xformers
-rw-r--r--modules/sd_hijack_optimizations.py28
1 files changed, 3 insertions, 25 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index da1b76e1..7fb4a45e 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -94,39 +94,17 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
-def _maybe_init(self, x):
- """
- Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x
- : B, Head, Length
- """
- if self.attention_op is not None:
- return
- _, M, K = x.shape
- try:
- self.attention_op = xformers.ops.AttentionOpDispatch(
- dtype=x.dtype,
- device=x.device,
- k=K,
- attn_bias_type=type(None),
- has_dropout=False,
- kv_len=M,
- q_len=M,
- ).op
- except NotImplementedError as err:
- raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
-
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- self._maybe_init(q)
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
def cross_attention_attnblock_forward(self, x):