author    C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>  2022-09-13 14:29:56 +0300
committer GitHub <noreply@github.com>  2022-09-13 14:29:56 +0300
commit    3b1b1444d4d90415fb42252406437b3d2ceb2110
tree      c5b116e5a7ce5e63db13eba3f3ce753576d93ead
parent    c84e333622883be7f1c2efa08259d36d27cadec7
Complete cross attention update
-rw-r--r--  modules/sd_hijack.py | 74
1 file changed, 73 insertions(+), 1 deletion(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c058ac6e..ec7d14cb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -11,7 +11,7 @@ from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
-
+import ldm.modules.diffusionmodules.model
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@@ -100,6 +100,76 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    return self.to_out(r2)
+def nonlinearity_hijack(x):
+    # swish
+    t = torch.sigmoid(x)
+    x *= t
+    del t
+
+    return x
+
+def cross_attention_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q1 = self.q(h_)
+    k1 = self.k(h_)
+    v = self.v(h_)
+
+    # compute attention
+    b, c, h, w = q1.shape
+
+    q2 = q1.reshape(b, c, h*w)
+    del q1
+
+    q = q2.permute(0, 2, 1)   # b,hw,c
+    del q2
+
+    k = k1.reshape(b, c, h*w) # b,c,hw
+    del k1
+
+    h_ = torch.zeros_like(k, device=q.device)
+
+    stats = torch.cuda.memory_stats(q.device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    mem_required = tensor_size * 2.5
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+    for i in range(0, q.shape[1], slice_size):
+        end = i + slice_size
+
+        w1 = torch.bmm(q[:, i:end], k)   # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+        w2 = w1 * (int(c)**(-0.5))
+        del w1
+        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+        del w2
+
+        # attend to values
+        v1 = v.reshape(b, c, h*w)
+        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
+        del w3
+
+        h_[:, :, i:end] = torch.bmm(v1, w4)   # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        del v1, w4
+
+    h2 = h_.reshape(b, c, h, w)
+    del h_
+
+    h3 = self.proj_out(h2)
+    del h2
+
+    h3 += x
+
+    return h3
class StableDiffusionModelHijack:
    ids_lookup = {}
@@ -175,6 +245,8 @@ class StableDiffusionModelHijack:
if cmd_opts.opt_split_attention:
    ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+    ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
+    ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
    ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
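
Note: the added cross_attention_attnblock_forward avoids materialising the full b x hw x hw attention matrix at once by computing it in slices along the query's spatial axis, with the slice count derived from free GPU memory. Below is a minimal standalone sketch of that step-count calculation, assuming the free-memory figure is passed in by the caller rather than read from torch.cuda.memory_stats / torch.cuda.mem_get_info as the committed code does; the 2.5x safety margin and the power-of-two rounding mirror the commit.

    import math

    def attention_steps(b, hw, element_size, mem_free_total):
        # bytes needed for one full b x hw x hw attention tensor
        tensor_size = b * hw * hw * element_size
        # same safety margin the commit applies on top of the raw tensor size
        mem_required = tensor_size * 2.5
        steps = 1
        if mem_required > mem_free_total:
            # round the slice count up to the next power of two
            steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
        # like the commit, fall back to one full-size slice if hw is not divisible
        slice_size = hw // steps if hw % steps == 0 else hw
        return steps, slice_size

    # Example: a 512x512 image gives a 64x64 latent (hw = 4096); with fp16
    # (element_size = 2) and 2 MiB free, attention_steps(1, 4096, 2, 2 * 1024**2)
    # returns (64, 64), i.e. 64 slices of 64 query positions each.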