From 5bb126bd89dd0fb87280472f472388a6f230c270 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 5 Sep 2022 01:41:20 +0300 Subject: add split attention layer optimization from https://github.com/basujindal/stable-diffusion/pull/117 --- modules/sd_hijack.py | 44 +++++++++++++++++++++++++++++++++++++++++++- modules/shared.py | 1 + 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6ee92e77..1dbdc9ce 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,8 +3,43 @@ import sys import traceback import torch import numpy as np +from torch import einsum -from modules.shared import opts, device +from modules.shared import opts, device, cmd_opts + +from ldm.util import default +from einops import rearrange +import ldm.modules.attention + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class StableDiffusionModelHijack: @@ -67,6 +102,9 @@ class StableDiffusionModelHijack: model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + if cmd_opts.opt_split_attention: + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): @@ -205,4 +243,8 @@ class EmbeddingsWithFixes(torch.nn.Module): return inputs_embeds + + + + model_hijack = StableDiffusionModelHijack() diff --git a/modules/shared.py b/modules/shared.py index 72e92eb9..dbfa7838 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -29,6 +29,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) +parser.add_argument("--opt-split-attention", type=str, help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance", default=os.path.join(script_path, 'ESRGAN')) cmd_opts = parser.parse_args() cpu = torch.device("cpu") -- cgit v1.2.1