aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py292
1 files changed, 211 insertions, 81 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 02c87f40..c02d954c 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,7 @@
import math
import sys
import traceback
-import importlib
+import psutil
import torch
from torch import einsum
@@ -9,9 +9,11 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared
+from modules import shared, errors, devices
from modules.hypernetworks import hypernetwork
+from .sub_quadratic_attention import efficient_dot_product_attention
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
print(traceback.format_exc(), file=sys.stderr)
+def get_available_vram():
+ if shared.device.type == 'cuda':
+ stats = torch.cuda.memory_stats(shared.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ return mem_free_total
+ else:
+ return psutil.virtual_memory().available
+
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
@@ -29,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
del context, context_k, context_v, x
@@ -37,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
- s2 = s1.softmax(dim=-1)
- del s1
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+ del q, k, v
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
- del q, k, v
+ r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -63,54 +85,56 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- k_in *= self.scale
-
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+ dtype = q_in.dtype
+ if shared.opts.upcast_attn:
+ q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k_in = k_in * self.scale
+
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+ mem_free_total = get_available_vram()
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+
+ del q, k, v
- del q, k, v
+ r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -118,19 +142,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
-def check_for_psutil():
- try:
- spec = importlib.util.find_spec('psutil')
- return spec is not None
- except ModuleNotFoundError:
- return False
-
-invokeAI_mps_available = check_for_psutil()
-
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-if invokeAI_mps_available:
- import psutil
- mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
@@ -204,29 +217,131 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
- k = self.to_k(context_k) * self.scale
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r = einsum_op(q, k, v)
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+ with devices.without_autocast(disable=not shared.opts.upcast_attn):
+ k = k * self.scale
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+ assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+ x = x.to(dtype)
+
+ x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+ out_proj, dropout = self.to_out
+ x = out_proj(x)
+ x = dropout(x)
+
+ return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+ bytes_per_token = torch.finfo(q.dtype).bits//8
+ batch_x_heads, q_tokens, _ = q.shape
+ _, k_tokens, _ = k.shape
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+ if chunk_threshold is None:
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+ elif chunk_threshold == 0:
+ chunk_threshold_bytes = None
+ else:
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+ kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+ elif kv_chunk_size_min == 0:
+ kv_chunk_size_min = None
+
+ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+ # the big matmul fits into our memory limit; do everything in 1 chunk,
+ # i.e. send it down the unchunked fast-path
+ query_chunk_size = q_tokens
+ kv_chunk_size = k_tokens
+
+ with devices.without_autocast(disable=q.dtype == v.dtype):
+ return efficient_dot_product_attention(
+ q,
+ k,
+ v,
+ query_chunk_size=q_chunk_size,
+ kv_chunk_size=kv_chunk_size,
+ kv_chunk_size_min = kv_chunk_size_min,
+ use_checkpoint=use_checkpoint,
+ )
+
+
+def get_xformers_flash_attention_op(q, k, v):
+ if not shared.cmd_opts.xformers_flash_attention:
+ return None
+
+ try:
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+ fw, bw = flash_attention_op
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+ return flash_attention_op
+ except Exception as e:
+ errors.display_once(e, "enabling flash attention")
+
+ return None
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
+
+ out = out.to(dtype)
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -252,12 +367,7 @@ def cross_attention_attnblock_forward(self, x):
h_ = torch.zeros_like(k, device=q.device)
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
+ mem_free_total = get_available_vram()
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
@@ -303,12 +413,32 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
- out = xformers.ops.memory_efficient_attention(q, k, v)
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+ out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out