Diffstat (limited to 'modules/hypertile.py')
-rw-r--r--  modules/hypertile.py  |  371
1 file changed, 0 insertions, 371 deletions
diff --git a/modules/hypertile.py b/modules/hypertile.py
deleted file mode 100644
index be898fce..00000000
--- a/modules/hypertile.py
+++ /dev/null
@@ -1,371 +0,0 @@
-"""
-Hypertile module for splitting attention layers in the SD-1.5/SD-XL U-Net and the SD-1.5 VAE
-Warning: the patch works well only if the input image width and height are multiples of 128
-Author: @tfernd   GitHub: https://github.com/tfernd/HyperTile
-"""
-
-from __future__ import annotations
-from typing import Callable
-from typing_extensions import Literal
-
-import logging
-from functools import wraps, cache
-from contextlib import contextmanager
-
-import math
-import torch.nn as nn
-import random
-
-from einops import rearrange
-
-# TODO add SD-XL layers
-DEPTH_LAYERS = {
- 0: [
- # SD 1.5 U-Net (diffusers)
- "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
- "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
- "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
- # SD 1.5 U-Net (ldm)
- "input_blocks.1.1.transformer_blocks.0.attn1",
- "input_blocks.2.1.transformer_blocks.0.attn1",
- "output_blocks.9.1.transformer_blocks.0.attn1",
- "output_blocks.10.1.transformer_blocks.0.attn1",
- "output_blocks.11.1.transformer_blocks.0.attn1",
- # SD 1.5 VAE
- "decoder.mid_block.attentions.0",
- "decoder.mid.attn_1",
- ],
- 1: [
- # SD 1.5 U-Net (diffusers)
- "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
- "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
- "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
- # SD 1.5 U-Net (ldm)
- "input_blocks.4.1.transformer_blocks.0.attn1",
- "input_blocks.5.1.transformer_blocks.0.attn1",
- "output_blocks.6.1.transformer_blocks.0.attn1",
- "output_blocks.7.1.transformer_blocks.0.attn1",
- "output_blocks.8.1.transformer_blocks.0.attn1",
- ],
- 2: [
- # SD 1.5 U-Net (diffusers)
- "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
- "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
- "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
- # SD 1.5 U-Net (ldm)
- "input_blocks.7.1.transformer_blocks.0.attn1",
- "input_blocks.8.1.transformer_blocks.0.attn1",
- "output_blocks.3.1.transformer_blocks.0.attn1",
- "output_blocks.4.1.transformer_blocks.0.attn1",
- "output_blocks.5.1.transformer_blocks.0.attn1",
- ],
- 3: [
- # SD 1.5 U-Net (diffusers)
- "mid_block.attentions.0.transformer_blocks.0.attn1",
- # SD 1.5 U-Net (ldm)
- "middle_block.1.transformer_blocks.0.attn1",
- ],
-}
-# SD-XL layers, thanks to GitHub user @gel-crabs for the help
-DEPTH_LAYERS_XL = {
- 0: [
-        # SD-XL U-Net (diffusers)
- "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
- "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
- "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
- "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
-        # SD-XL U-Net (ldm)
- "input_blocks.4.1.transformer_blocks.0.attn1",
- "input_blocks.5.1.transformer_blocks.0.attn1",
- "output_blocks.3.1.transformer_blocks.0.attn1",
- "output_blocks.4.1.transformer_blocks.0.attn1",
- "output_blocks.5.1.transformer_blocks.0.attn1",
-        # SD-XL VAE
- "decoder.mid_block.attentions.0",
- "decoder.mid.attn_1",
- ],
- 1: [
-        # SD-XL U-Net (diffusers)
- #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
- #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
- #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
- #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
- #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
-        # SD-XL U-Net (ldm)
- "input_blocks.4.1.transformer_blocks.1.attn1",
- "input_blocks.5.1.transformer_blocks.1.attn1",
- "output_blocks.3.1.transformer_blocks.1.attn1",
- "output_blocks.4.1.transformer_blocks.1.attn1",
- "output_blocks.5.1.transformer_blocks.1.attn1",
- "input_blocks.7.1.transformer_blocks.0.attn1",
- "input_blocks.8.1.transformer_blocks.0.attn1",
- "output_blocks.0.1.transformer_blocks.0.attn1",
- "output_blocks.1.1.transformer_blocks.0.attn1",
- "output_blocks.2.1.transformer_blocks.0.attn1",
- "input_blocks.7.1.transformer_blocks.1.attn1",
- "input_blocks.8.1.transformer_blocks.1.attn1",
- "output_blocks.0.1.transformer_blocks.1.attn1",
- "output_blocks.1.1.transformer_blocks.1.attn1",
- "output_blocks.2.1.transformer_blocks.1.attn1",
- "input_blocks.7.1.transformer_blocks.2.attn1",
- "input_blocks.8.1.transformer_blocks.2.attn1",
- "output_blocks.0.1.transformer_blocks.2.attn1",
- "output_blocks.1.1.transformer_blocks.2.attn1",
- "output_blocks.2.1.transformer_blocks.2.attn1",
- "input_blocks.7.1.transformer_blocks.3.attn1",
- "input_blocks.8.1.transformer_blocks.3.attn1",
- "output_blocks.0.1.transformer_blocks.3.attn1",
- "output_blocks.1.1.transformer_blocks.3.attn1",
- "output_blocks.2.1.transformer_blocks.3.attn1",
- "input_blocks.7.1.transformer_blocks.4.attn1",
- "input_blocks.8.1.transformer_blocks.4.attn1",
- "output_blocks.0.1.transformer_blocks.4.attn1",
- "output_blocks.1.1.transformer_blocks.4.attn1",
- "output_blocks.2.1.transformer_blocks.4.attn1",
- "input_blocks.7.1.transformer_blocks.5.attn1",
- "input_blocks.8.1.transformer_blocks.5.attn1",
- "output_blocks.0.1.transformer_blocks.5.attn1",
- "output_blocks.1.1.transformer_blocks.5.attn1",
- "output_blocks.2.1.transformer_blocks.5.attn1",
- "input_blocks.7.1.transformer_blocks.6.attn1",
- "input_blocks.8.1.transformer_blocks.6.attn1",
- "output_blocks.0.1.transformer_blocks.6.attn1",
- "output_blocks.1.1.transformer_blocks.6.attn1",
- "output_blocks.2.1.transformer_blocks.6.attn1",
- "input_blocks.7.1.transformer_blocks.7.attn1",
- "input_blocks.8.1.transformer_blocks.7.attn1",
- "output_blocks.0.1.transformer_blocks.7.attn1",
- "output_blocks.1.1.transformer_blocks.7.attn1",
- "output_blocks.2.1.transformer_blocks.7.attn1",
- "input_blocks.7.1.transformer_blocks.8.attn1",
- "input_blocks.8.1.transformer_blocks.8.attn1",
- "output_blocks.0.1.transformer_blocks.8.attn1",
- "output_blocks.1.1.transformer_blocks.8.attn1",
- "output_blocks.2.1.transformer_blocks.8.attn1",
- "input_blocks.7.1.transformer_blocks.9.attn1",
- "input_blocks.8.1.transformer_blocks.9.attn1",
- "output_blocks.0.1.transformer_blocks.9.attn1",
- "output_blocks.1.1.transformer_blocks.9.attn1",
- "output_blocks.2.1.transformer_blocks.9.attn1",
- ],
- 2: [
-        # SD-XL U-Net (diffusers)
- "mid_block.attentions.0.transformer_blocks.0.attn1",
-        # SD-XL U-Net (ldm)
- "middle_block.1.transformer_blocks.0.attn1",
- "middle_block.1.transformer_blocks.1.attn1",
- "middle_block.1.transformer_blocks.2.attn1",
- "middle_block.1.transformer_blocks.3.attn1",
- "middle_block.1.transformer_blocks.4.attn1",
- "middle_block.1.transformer_blocks.5.attn1",
- "middle_block.1.transformer_blocks.6.attn1",
- "middle_block.1.transformer_blocks.7.attn1",
- "middle_block.1.transformer_blocks.8.attn1",
- "middle_block.1.transformer_blocks.9.attn1",
- ],
- 3 : [] # TODO - separate layers for SD-XL
-}
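-
-# Editor's note (hedged summary of the hijack loop in split_attention below): with max_depth=d,
-# every submodule whose qualified name ends with an entry in layers[0] .. layers[d] is patched,
-# where `layers` is DEPTH_LAYERS_XL when is_sdxl=True and DEPTH_LAYERS otherwise.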
-
-
-RNG_INSTANCE = random.Random()
-
-def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
- """
-    Returns a random divisor x of value such that
-        x * min_value <= value
-    If max_options is 1, the behavior is deterministic: the largest such divisor is returned
- """
- min_value = min(min_value, value)
-
- # All big divisors of value (inclusive)
- divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
-
- ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
-
- idx = RNG_INSTANCE.randint(0, len(ns) - 1)
-
- return ns[idx]
-
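-# Editor's worked example for random_divisor above (illustrative values, not from the original code);
-# the result is always a divisor of `value`:
-#   random_divisor(64, 8)                 # -> 8, deterministic: the only candidate is 64 // 8
-#   random_divisor(64, 8, max_options=2)  # -> 8 or 4, picked from {64 // 8, 64 // 16} by RNG_INSTANCE
-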
-def set_hypertile_seed(seed: int) -> None:
- RNG_INSTANCE.seed(seed)
-
-def largest_tile_size_available(width:int, height:int) -> int:
- """
-    Calculates the largest tile size usable for a given width and height:
-    the largest power of 2 that divides both, i.e. that divides gcd(width, height)
- """
- gcd = math.gcd(width, height)
- largest_tile_size_available = 1
- while gcd % (largest_tile_size_available * 2) == 0:
- largest_tile_size_available *= 2
- return largest_tile_size_available
-
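-# Editor's worked example for largest_tile_size_available above (illustrative sizes, not from the
-# original code); the result is the largest power of 2 dividing gcd(width, height):
-#   largest_tile_size_available(512, 512)    # -> 512, since gcd is 512, itself a power of 2
-#   largest_tile_size_available(1920, 1080)  # -> 8,   since gcd is 120 = 8 * 15
-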
-def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
- """
-    Finds h and w such that h*w = hw and w/h is as close as possible to aspect_ratio
-    We check all divisor pairs of hw and return the pair whose ratio is closest to the aspect ratio
- """
- divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
- pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
- ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
- closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
- closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
- return closest_pair
-
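-# Editor's worked example for iterative_closest_divisors above (illustrative values, not from the
-# original code); the returned pair multiplies to hw and its w/h ratio is the closest one available:
-#   iterative_closest_divisors(96, 1.5)  # -> (8, 12), since 12 / 8 == 1.5 exactly
-#   iterative_closest_divisors(48, 1.5)  # -> (6, 8),  since 8 / 6 ~= 1.33 is the closest ratio on offer
-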
-@cache
-def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
- """
-    Finds h and w such that h*w = hw and their ratio matches aspect_ratio,
-    falling back to iterative_closest_divisors when no exact integer pair is found
- """
- h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
- # find h and w such that h*w = hw and h/w = aspect_ratio
- if h * w != hw:
- w_candidate = hw / h
- # check if w is an integer
- if not w_candidate.is_integer():
- h_candidate = hw / w
- # check if h is an integer
- if not h_candidate.is_integer():
- return iterative_closest_divisors(hw, aspect_ratio)
- else:
- h = int(h_candidate)
- else:
- w = int(w_candidate)
- return h, w
-
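-# Editor's worked example for find_hw_candidates above (illustrative values, not from the original
-# code); it factors a flattened token count hw back into a plausible (h, w) pair:
-#   find_hw_candidates(64 * 64, 1.0)  # -> (64, 64), the exact square-root pair
-#   find_hw_candidates(96 * 64, 1.5)  # -> (96, 64)
-#   find_hw_candidates(50, 1.5)       # -> (5, 10), via the iterative_closest_divisors fallback
-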
-@contextmanager
-def split_attention(
- layer: nn.Module,
- /,
- aspect_ratio: float, # width/height
- tile_size: int = 128, # 128 for VAE
- swap_size: int = 1, # 1 for VAE
- *,
- disable: bool = False,
- max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1
- scale_depth: bool = True, # scale the tile-size depending on the depth
- is_sdxl: bool = False, # is the model SD-XL
-):
- # Hijacks AttnBlock from ldm and Attention from diffusers
-
- if disable:
-        logging.info(f"Attention for {layer.__class__.__qualname__} not split")
- yield
- return
-
- latent_tile_size = max(128, tile_size) // 8
-
- def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
- @wraps(forward)
- def wrapper(*args, **kwargs):
- x = args[0]
-
- # VAE
- if x.ndim == 4:
- b, c, h, w = x.shape
-
- nh = random_divisor(h, latent_tile_size, swap_size)
- nw = random_divisor(w, latent_tile_size, swap_size)
-
- if nh * nw > 1:
- x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
-
- out = forward(x, *args[1:], **kwargs)
-
- if nh * nw > 1:
- out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
-
- # U-Net
- else:
- hw: int = x.size(1)
- h, w = find_hw_candidates(hw, aspect_ratio)
- assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
-
- factor = 2**depth if scale_depth else 1
- nh = random_divisor(h, latent_tile_size * factor, swap_size)
- nw = random_divisor(w, latent_tile_size * factor, swap_size)
-
- module._split_sizes_hypertile.append((nh, nw)) # type: ignore
-
- if nh * nw > 1:
- x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
-
- out = forward(x, *args[1:], **kwargs)
-
- if nh * nw > 1:
- out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
- out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
-
- return out
-
- return wrapper
-
- # Handle hijacking the forward method and recovering afterwards
- try:
- if is_sdxl:
- layers = DEPTH_LAYERS_XL
- else:
- layers = DEPTH_LAYERS
- for depth in range(max_depth + 1):
- for layer_name, module in layer.named_modules():
- if any(layer_name.endswith(try_name) for try_name in layers[depth]):
-                    # log which attention layer is being hijacked, for debugging
- logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
- # hijack
- module._original_forward_hypertile = module.forward
- module.forward = self_attn_forward(module.forward, depth, layer_name, module)
- module._split_sizes_hypertile = []
- yield
- finally:
- for layer_name, module in layer.named_modules():
- # remove hijack
- if hasattr(module, "_original_forward_hypertile"):
- if module._split_sizes_hypertile:
-                    logging.debug(f"layer {layer_name} split with sizes {module._split_sizes_hypertile}")
- # recover
- module.forward = module._original_forward_hypertile
- del module._original_forward_hypertile
- del module._split_sizes_hypertile
-
-def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
- """
-    Returns a split-attention context manager for the VAE, configured from opts
- """
- enabled = opts.hypertile_split_vae_attn
- swap_size = opts.hypertile_swap_size_vae
- max_depth = opts.hypertile_max_depth_vae
- tile_size_max = opts.hypertile_max_tile_vae
- return split_attention(
- model,
- aspect_ratio=aspect_ratio,
- tile_size=min(tile_size, tile_size_max),
- swap_size=swap_size,
- disable=not enabled,
- max_depth=max_depth,
- is_sdxl=False,
- )
-
-def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool):
- """
-    Returns a split-attention context manager for the U-Net, configured from opts
- """
- enabled = opts.hypertile_split_unet_attn
- swap_size = opts.hypertile_swap_size_unet
- max_depth = opts.hypertile_max_depth_unet
- tile_size_max = opts.hypertile_max_tile_unet
- return split_attention(
- model,
- aspect_ratio=aspect_ratio,
- tile_size=min(tile_size, tile_size_max),
- swap_size=swap_size,
- disable=not enabled,
- max_depth=max_depth,
- is_sdxl=is_sdxl,
- )
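-
-# Editor's hedged usage sketch (not part of the original module): build a tiny stand-in model whose
-# qualified submodule name ends with a depth-0 ldm entry, then check that split_attention() hijacks
-# its forward inside the context and restores it afterwards. The dummy classes, sizes and seed below
-# are assumptions made purely for illustration.
-if __name__ == "__main__":
-    import torch
-
-    class DummySelfAttn(nn.Module):
-        def forward(self, x):
-            return x  # identity "attention"; hypertile only reshapes batch/sequence dims around it
-
-    attn = DummySelfAttn()
-    block = nn.ModuleDict({"transformer_blocks": nn.ModuleList([nn.ModuleDict({"attn1": attn})])})
-    model = nn.ModuleDict({
-        "input_blocks": nn.ModuleList([nn.Identity(), nn.ModuleList([nn.Identity(), block])]),
-    })
-    # named_modules() now yields "input_blocks.1.1.transformer_blocks.0.attn1", a DEPTH_LAYERS[0] entry
-
-    set_hypertile_seed(0)
-    x = torch.randn(1, 64 * 64, 8)  # (batch, h*w, channels) for a square 64x64 latent
-    with split_attention(model, aspect_ratio=1.0, tile_size=128, swap_size=1, max_depth=0):
-        assert hasattr(attn, "_original_forward_hypertile")  # forward was hijacked
-        out = attn.forward(x)  # the input is split into a 4x4 grid of tiles, attended, and merged back
-        assert out.shape == x.shape
-    assert not hasattr(attn, "_original_forward_hypertile")  # forward was restored on exit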