"""Alternative attention implementations and the registry objects that install them.

Each ``SdOptimization`` subclass swaps ``ldm.modules.attention.CrossAttention.forward``
(and, for most of them, ``ldm.modules.diffusionmodules.model.AttnBlock.forward``)
for one of the functions defined further down in this module.
"""
from __future__ import annotations
import math
import sys
import traceback
import psutil

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

# Keep a handle on the AttnBlock.forward that was installed at import time so
# that SdOptimization.undo() can put it back.
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    """Base description of one attention optimization.

    Subclasses override the class attributes and ``apply``; ``undo`` is shared.
    """

    # Short identifier; also the first part of ``title()``.
    name: str = None
    # Optional human-readable description appended to the name in ``title()``.
    label: str | None = None
    # Name of an associated command-line option (how it is consumed is not
    # visible in this module).
    cmd_opt: str | None = None
    # Numeric priority; the code that ranks optimizers lives outside this
    # module -- presumably higher wins, TODO confirm.
    priority: int = 0

    def title(self):
        """Return ``name``, or ``"name - label"`` when a label is set."""
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        """Return whether this optimization can be used; the base class always says yes."""
        return True

    def apply(self):
        """Install the optimized forward functions; the base class installs nothing."""
        pass

    def undo(self):
        """Restore the hypernetwork-aware CrossAttention.forward and the saved AttnBlock.forward."""
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    """Attention backed by ``xformers.ops.memory_efficient_attention``."""

    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        """Return a truthy value when xformers is forced on, or is importable on a CUDA device
        whose compute capability lies between (6, 0) and (9, 0) inclusive."""
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    """``torch.nn.functional.scaled_dot_product_attention`` with the memory-efficient kernel disabled."""

    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        """Return True when the installed torch exposes a callable ``scaled_dot_product_attention``."""
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    """``torch.nn.functional.scaled_dot_product_attention`` with torch's default kernel selection.

    Inherits the availability check from ``SdOptimizationSdpNoMem``.
    """

    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    """Sub-quadratic (chunked) attention."""

    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    """Original "v1" split attention; only CrossAttention.forward is replaced."""

    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    """InvokeAI-derived split attention; only CrossAttention.forward is replaced."""

    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        """Return 1000 when CUDA is not available, otherwise 10."""
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    """Doggettx-derived split attention for both CrossAttention and AttnBlock."""

    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    """Append one instance of every optimization defined in this module to the list ``res``."""
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])


# xformers is only imported when requested on the command line. A failed import
# is reported on stderr and leaves shared.xformers_available untouched.
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


def get_available_vram():
    """Return an estimate, in bytes, of memory available for new tensors.

    On CUDA this is the free device memory plus memory torch has reserved but
    is not actively using; on any other device type it is the available system
    RAM as reported by psutil.
    """
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    """CrossAttention.forward replacement that processes the (batch*heads) axis two rows at a time.

    ``context`` defaults to ``x``; ``mask`` is accepted but not used.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    # Remember the incoming dtype so the result can be cast back after an upcast.
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
    """CrossAttention.forward replacement that slices the query axis to fit in available memory.

    ``context`` defaults to ``x``; ``mask`` is accepted but not used.

    Raises:
        RuntimeError: when more than 64 slices would be needed.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        # v is left in its original dtype on mps devices.
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # Fold the attention scale into k once instead of scaling every score slice.
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        # Size in bytes of the full (batch*heads, q_tokens, k_tokens) score tensor.
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            # Smallest power of two that brings the per-slice requirement under the free memory.
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        # Slicing is only used when the query length divides evenly; otherwise
        # everything is processed in a single slice.
        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
# Total system RAM in whole GiB, sampled once at import time.
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    """Compute softmax(q @ k^T) @ v in one shot; no scaling is applied here."""
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    """Run ``einsum_op_compvis`` over dim 0 in chunks of ``slice_size`` rows."""
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    """Run ``einsum_op_compvis`` over the query axis (dim 1) in chunks of ``slice_size``."""
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    """mps strategy: unsliced for small inputs, otherwise slice along the query axis."""
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        # Slice sizes that are an exact multiple of 4096 are nudged down by one.
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    """mps strategy: unsliced when RAM > 8 GiB and the input is small, otherwise one row of dim 0 at a time."""
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    """Slice the attention so the score tensor stays within roughly ``max_tensor_mb`` MiB.

    Slicing is done along dim 0 when that gives enough pieces, otherwise along
    the query axis.
    """
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    # Power-of-two divisor derived from the size-to-budget ratio.
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))


def einsum_op_cuda(q, k, v):
    """CUDA strategy: budget the score tensor from the memory currently free on q's device."""
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


def einsum_op(q, k, v):
    """Dispatch to a device-specific attention strategy based on ``q.device.type``."""
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)


def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    """CrossAttention.forward replacement that delegates the attention itself to ``einsum_op``.

    ``context`` defaults to ``x``; ``mask`` is accepted but not used.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        # v is left in its original dtype on mps devices.
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # einsum_op applies no scaling, so fold the scale into k here.
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None):
    """CrossAttention.forward replacement built on ``sub_quad_attention``.

    ``context`` defaults to ``x``; ``mask`` must be None.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    # (b, n, h*d) -> (b*h, n, d)
    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        # Only q and k are upcast here; sub_quad_attention switches autocast
        # off when q and v end up with different dtypes.
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    # (b*h, n, d) -> (b, n, h*d)
    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x


def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Choose chunking parameters from available memory and call ``efficient_dot_product_attention``.

    Args:
        q, k, v: 3-D tensors; dim 0 is treated as batch*heads and dim 1 as tokens.
        q_chunk_size: passed through as ``query_chunk_size``.
        kv_chunk_size: passed through, unless the whole score matrix fits the
            threshold, in which case it is replaced by the number of key tokens.
        kv_chunk_size_min: ``None`` derives a value from the threshold; ``0``
            is converted to ``None``.
        chunk_threshold: ``None`` uses 70% of available memory (90% on mps),
            ``0`` disables the threshold, any other value is taken as a
            percentage of available memory.
        use_checkpoint: passed through unchanged.
    """
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        # NOTE(review): only kv_chunk_size is widened here; q_chunk_size is
        # still passed through unchanged below.
        kv_chunk_size = k_tokens

    # Autocast stays enabled only when q and v already share a dtype.
    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min = kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )


def get_xformers_flash_attention_op(q, k, v):
    """Return the xformers flash-attention op if it is enabled and supports these inputs, else None.

    Any exception raised while probing is reported through
    ``errors.display_once`` and results in None.
    """
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        # The op is a (forward, backward) pair; only the forward half is probed.
        fw, _bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None


def xformers_attention_forward(self, x, context=None, mask=None):
    """CrossAttention.forward replacement using ``xformers.ops.memory_efficient_attention``.

    ``context`` defaults to ``x``; ``mask`` is accepted but not used.
    """
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)


# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    """CrossAttention.forward replacement using ``torch.nn.functional.scaled_dot_product_attention``.

    ``context`` defaults to ``x``. A non-None ``mask`` is prepared with
    ``self.prepare_attention_mask`` and passed on as ``attn_mask``.
    """
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states


def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
    """Same as ``scaled_dot_product_attention_forward`` but with the memory-efficient SDP kernel disabled."""
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)


def cross_attention_attnblock_forward(self, x):
    """AttnBlock.forward replacement that slices the query axis to fit in available memory.

    Returns ``proj_out(attention(norm(x))) + x``.
    """
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    # Loop-invariant, so reshaped once up front rather than on every slice.
    v = v.reshape(b, c, h*w)  # b,c,hw

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    # Size in bytes of the full (b, hw, hw) score tensor.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        # Smallest power of two that brings the per-slice requirement under the free memory.
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    # Slicing is only used when the query length divides evenly; otherwise
    # everything is processed in a single slice.
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3


def xformers_attnblock_forward(self, x):
    """AttnBlock.forward replacement using ``xformers.ops.memory_efficient_attention``.

    Falls back to ``cross_attention_attnblock_forward`` when xformers raises
    ``NotImplementedError``.
    """
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            # Upcast v together with q and k so all three inputs share one
            # dtype, as xformers_attention_forward already does.
            q, k, v = q.float(), k.float(), v.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)


def sdp_attnblock_forward(self, x):
    """AttnBlock.forward replacement using ``torch.nn.functional.scaled_dot_product_attention``."""
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        # Upcast v together with q and k so all three inputs share one dtype,
        # as scaled_dot_product_attention_forward already does.
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out


def sdp_no_mem_attnblock_forward(self, x):
    """Same as ``sdp_attnblock_forward`` but with the memory-efficient SDP kernel disabled."""
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)


def sub_quad_attnblock_forward(self, x):
    """AttnBlock.forward replacement built on ``sub_quad_attention``."""
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out