aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorPam <pamhome21@gmail.com>2023-03-10 22:48:41 +0500
committerPam <pamhome21@gmail.com>2023-03-10 22:48:41 +0500
commit8d7fa2f67cb0554d8902d5d407166876020e067e (patch)
treee5d3c27a08b64bc840560ac290571097fda67b8c /modules/sd_hijack_optimizations.py
parent0981dea94832f34d638b1aa8964cfaeffd223b47 (diff)
sdp_attnblock_forward hijack
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 68b1dd84..2e307b5d 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -473,6 +473,30 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+def sdp_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+ out = out.to(dtype)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
+
+def sdp_no_mem_attnblock_forward(self, x):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return sdp_attnblock_forward(self, x)
+
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)