aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py173
1 files changed, 146 insertions, 27 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f10865cd..2ec0b049 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import math
import sys
import traceback
@@ -9,10 +10,129 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork
-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+
+class SdOptimization:
+ name: str = None
+ label: str | None = None
+ cmd_opt: str | None = None
+ priority: int = 0
+
+ def title(self):
+ if self.label is None:
+ return self.name
+
+ return f"{self.name} - {self.label}"
+
+ def is_available(self):
+ return True
+
+ def apply(self):
+ pass
+
+ def undo(self):
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+ name = "xformers"
+ cmd_opt = "xformers"
+ priority = 100
+
+ def is_available(self):
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+ name = "sdp-no-mem"
+ label = "scaled dot product without memory efficient attention"
+ cmd_opt = "opt_sdp_no_mem_attention"
+ priority = 90
+
+ def is_available(self):
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+ name = "sdp"
+ label = "scaled dot product"
+ cmd_opt = "opt_sdp_attention"
+ priority = 80
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+ name = "sub-quadratic"
+ cmd_opt = "opt_sub_quad_attention"
+ priority = 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+ name = "V1"
+ label = "original v1"
+ cmd_opt = "opt_split_attention_v1"
+ priority = 10
+
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+ name = "InvokeAI"
+ cmd_opt = "opt_split_attention_invokeai"
+
+ @property
+ def priority(self):
+ return 1000 if not torch.cuda.is_available() else 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+ name = "Doggettx"
+ cmd_opt = "opt_split_attention"
+ priority = 20
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+ res.extend([
+ SdOptimizationXformers(),
+ SdOptimizationSdpNoMem(),
+ SdOptimizationSdp(),
+ SdOptimizationSubQuad(),
+ SdOptimizationV1(),
+ SdOptimizationInvokeAI(),
+ SdOptimizationDoggettx(),
+ ])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -49,7 +169,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
v_in = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -62,10 +182,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
-
+
s2 = s1.softmax(dim=-1)
del s1
-
+
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
del q, k, v
@@ -95,43 +215,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale
-
+
del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
-
+
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
+
mem_free_total = get_available_vram()
-
+
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
-
+
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
+
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
+
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
-
+
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
-
+
del q, k, v
r1 = r1.to(dtype)
@@ -228,8 +348,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -296,11 +416,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
- query_chunk_size = q_tokens
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
+ return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,
@@ -335,7 +454,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -370,7 +489,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
-
+
del q_in, k_in, v_in
dtype = q.dtype
@@ -452,7 +571,7 @@ def cross_attention_attnblock_forward(self, x):
h3 += x
return h3
-
+
def xformers_attnblock_forward(self, x):
try:
h_ = x
@@ -461,7 +580,7 @@ def xformers_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -483,7 +602,7 @@ def sdp_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -507,7 +626,7 @@ def sub_quad_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()