aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-09-18 07:28:53 +0300
committerGitHub <noreply@github.com>2022-09-18 07:28:53 +0300
commit17b60490fa59ab0b5577b6ebc1d231c9e89e3710 (patch)
treecc8622bb42ef7a7d055b70cc71385789e8c8bef4
parent8ff6f093206111940e2601187f3f208d761543d6 (diff)
parent18d6fe4346e2543522cd2a64c71207e45632a46b (diff)
Merge pull request #635 from C43H66N12O12S2/attention
Move scale multiplication to the front
-rw-r--r--modules/sd_hijack.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 65414518..c4450ce4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -50,7 +50,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- k_in = self.to_k(context)
+ k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
@@ -85,7 +85,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1