author     AUTOMATIC <16777216c@gmail.com>   2022-09-05 01:41:20 +0300
committer  AUTOMATIC <16777216c@gmail.com>   2022-09-05 01:41:20 +0300
commit     5bb126bd89dd0fb87280472f472388a6f230c270 (patch)
tree       c8f0a94fcd45ed37f052b8b7b3a8cac218dcfc2e
parent     407fc1fe0c4471837817d49f25bea1df3ec84ec8 (diff)
add split attention layer optimization from https://github.com/basujindal/stable-diffusion/pull/117
-rw-r--r--  modules/sd_hijack.py | 44
-rw-r--r--  modules/shared.py    |  1
2 files changed, 44 insertions(+), 1 deletion(-)
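The patch below avoids allocating the full (batch*heads, tokens, tokens) attention matrix at once: the softmax(Q·K^T)·V product is computed two elements of the batch*heads axis at a time, which caps peak VRAM at the cost of roughly 10% throughput. The following standalone sketch (illustration only, not part of the commit; shapes are invented) shows that the sliced computation matches the one-shot version:

import torch
from torch import einsum

# invented shapes: (batch*heads, query/key tokens, dim per head)
q = torch.randn(8, 64, 40)
k = torch.randn(8, 77, 40)
v = torch.randn(8, 77, 40)
scale = 40 ** -0.5

# full attention: one allocation of shape (8, 64, 77)
full = einsum('b i j, b j d -> b i d',
              (einsum('b i d, b j d -> b i j', q, k) * scale).softmax(dim=-1), v)

# same computation, two slices of the batch*heads axis at a time,
# mirroring split_cross_attention_forward in the diff below
r1 = torch.zeros_like(full)
for i in range(0, q.shape[0], 2):
    s = einsum('b i d, b j d -> b i j', q[i:i+2], k[i:i+2]) * scale
    r1[i:i+2] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1), v[i:i+2])

assert torch.allclose(full, r1, atol=1e-5)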
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6ee92e77..1dbdc9ce 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,8 +3,43 @@ import sys
import traceback
import torch
import numpy as np
+from torch import einsum
-from modules.shared import opts, device
+from modules.shared import opts, device, cmd_opts
+
+from ldm.util import default
+from einops import rearrange
+import ldm.modules.attention
+
+
+# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
+def split_cross_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
class StableDiffusionModelHijack:
@@ -67,6 +102,9 @@ class StableDiffusionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+ if cmd_opts.opt_split_attention:
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
@@ -205,4 +243,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
return inputs_embeds
+
+
+
+
model_hijack = StableDiffusionModelHijack()
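The hijack takes effect by monkey-patching: assigning split_cross_attention_forward onto ldm.modules.attention.CrossAttention.forward swaps the attention implementation for every CrossAttention instance in the loaded model without touching ldm itself. A hedged sketch of applying and reverting such a patch (assumes the webui repo root is on the path so both modules import; illustration only):

import ldm.modules.attention
from modules.sd_hijack import split_cross_attention_forward

# keep the stock implementation around so the patch can be undone
original_forward = ldm.modules.attention.CrossAttention.forward

# every CrossAttention instance now routes through the split implementation
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward

# ... run sampling as usual ...

# restore the original behaviour if the optimization is not wanted
ldm.modules.attention.CrossAttention.forward = original_forward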
diff --git a/modules/shared.py b/modules/shared.py
index 72e92eb9..dbfa7838 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -29,6 +29,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
+parser.add_argument("--opt-split-attention", type=str, help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance", default=os.path.join(script_path, 'ESRGAN'))
cmd_opts = parser.parse_args()
cpu = torch.device("cpu")