From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation --- modules/sd_hijack_optimizations.py | 124 +++++++++++++++++++++++++++++-------- 1 file changed, 99 insertions(+), 25 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 02c87f40..f5c153e8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import importlib +import psutil import torch from torch import einsum @@ -12,6 +12,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from .sub_quadratic_attention import efficient_dot_product_attention + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: print(traceback.format_exc(), file=sys.stderr) +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads @@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() @@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def check_for_psutil(): - try: - spec = importlib.util.find_spec('psutil') - return spec is not None - except ModuleNotFoundError: - return False - -invokeAI_mps_available = check_for_psutil() - # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -if invokeAI_mps_available: - import psutil - mem_total_gb = psutil.virtual_memory().total // (1 << 30) +mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # -- End of code from https://github.com/invoke-ai/InvokeAI -- + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." + + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + + if chunk_threshold_bytes is None: + chunk_threshold_bytes = available_vram + elif chunk_threshold_bytes == 0: + chunk_threshold_bytes = None + + if kv_chunk_size_min is None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x): h_ = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 @@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out -- cgit v1.2.1