aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/hypertile
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/hypertile')
-rw-r--r--extensions-builtin/hypertile/hypertile.py348
-rw-r--r--extensions-builtin/hypertile/scripts/hypertile_script.py73
2 files changed, 421 insertions, 0 deletions
diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py
new file mode 100644
index 00000000..a40c1311
--- /dev/null
+++ b/extensions-builtin/hypertile/hypertile.py
@@ -0,0 +1,348 @@
+"""
+Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
+Warn: The patch works well only if the input image has a width and height that are multiples of 128
+Original author: @tfernd Github: https://github.com/tfernd/HyperTile
+"""
+
+from __future__ import annotations
+
+import functools
+from dataclasses import dataclass
+from typing import Callable
+from typing_extensions import Literal
+
+import logging
+from functools import wraps, cache
+from contextlib import contextmanager
+
+import math
+import torch.nn as nn
+import random
+
+from einops import rearrange
+
+
+@dataclass
+class HypertileParams:
+ depth = 0
+ layer_name = ""
+ tile_size: int = 0
+ swap_size: int = 0
+ aspect_ratio: float = 1.0
+ forward = None
+ enabled = False
+
+
+
+# TODO add SD-XL layers
+DEPTH_LAYERS = {
+ 0: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.1.1.transformer_blocks.0.attn1",
+ "input_blocks.2.1.transformer_blocks.0.attn1",
+ "output_blocks.9.1.transformer_blocks.0.attn1",
+ "output_blocks.10.1.transformer_blocks.0.attn1",
+ "output_blocks.11.1.transformer_blocks.0.attn1",
+ # SD 1.5 VAE
+ "decoder.mid_block.attentions.0",
+ "decoder.mid.attn_1",
+ ],
+ 1: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.0.attn1",
+ "input_blocks.5.1.transformer_blocks.0.attn1",
+ "output_blocks.6.1.transformer_blocks.0.attn1",
+ "output_blocks.7.1.transformer_blocks.0.attn1",
+ "output_blocks.8.1.transformer_blocks.0.attn1",
+ ],
+ 2: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.7.1.transformer_blocks.0.attn1",
+ "input_blocks.8.1.transformer_blocks.0.attn1",
+ "output_blocks.3.1.transformer_blocks.0.attn1",
+ "output_blocks.4.1.transformer_blocks.0.attn1",
+ "output_blocks.5.1.transformer_blocks.0.attn1",
+ ],
+ 3: [
+ # SD 1.5 U-Net (diffusers)
+ "mid_block.attentions.0.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "middle_block.1.transformer_blocks.0.attn1",
+ ],
+}
+# XL layers, thanks for GitHub@gel-crabs for the help
+DEPTH_LAYERS_XL = {
+ 0: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.0.attn1",
+ "input_blocks.5.1.transformer_blocks.0.attn1",
+ "output_blocks.3.1.transformer_blocks.0.attn1",
+ "output_blocks.4.1.transformer_blocks.0.attn1",
+ "output_blocks.5.1.transformer_blocks.0.attn1",
+ # SD 1.5 VAE
+ "decoder.mid_block.attentions.0",
+ "decoder.mid.attn_1",
+ ],
+ 1: [
+ # SD 1.5 U-Net (diffusers)
+ #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.1.attn1",
+ "input_blocks.5.1.transformer_blocks.1.attn1",
+ "output_blocks.3.1.transformer_blocks.1.attn1",
+ "output_blocks.4.1.transformer_blocks.1.attn1",
+ "output_blocks.5.1.transformer_blocks.1.attn1",
+ "input_blocks.7.1.transformer_blocks.0.attn1",
+ "input_blocks.8.1.transformer_blocks.0.attn1",
+ "output_blocks.0.1.transformer_blocks.0.attn1",
+ "output_blocks.1.1.transformer_blocks.0.attn1",
+ "output_blocks.2.1.transformer_blocks.0.attn1",
+ "input_blocks.7.1.transformer_blocks.1.attn1",
+ "input_blocks.8.1.transformer_blocks.1.attn1",
+ "output_blocks.0.1.transformer_blocks.1.attn1",
+ "output_blocks.1.1.transformer_blocks.1.attn1",
+ "output_blocks.2.1.transformer_blocks.1.attn1",
+ "input_blocks.7.1.transformer_blocks.2.attn1",
+ "input_blocks.8.1.transformer_blocks.2.attn1",
+ "output_blocks.0.1.transformer_blocks.2.attn1",
+ "output_blocks.1.1.transformer_blocks.2.attn1",
+ "output_blocks.2.1.transformer_blocks.2.attn1",
+ "input_blocks.7.1.transformer_blocks.3.attn1",
+ "input_blocks.8.1.transformer_blocks.3.attn1",
+ "output_blocks.0.1.transformer_blocks.3.attn1",
+ "output_blocks.1.1.transformer_blocks.3.attn1",
+ "output_blocks.2.1.transformer_blocks.3.attn1",
+ "input_blocks.7.1.transformer_blocks.4.attn1",
+ "input_blocks.8.1.transformer_blocks.4.attn1",
+ "output_blocks.0.1.transformer_blocks.4.attn1",
+ "output_blocks.1.1.transformer_blocks.4.attn1",
+ "output_blocks.2.1.transformer_blocks.4.attn1",
+ "input_blocks.7.1.transformer_blocks.5.attn1",
+ "input_blocks.8.1.transformer_blocks.5.attn1",
+ "output_blocks.0.1.transformer_blocks.5.attn1",
+ "output_blocks.1.1.transformer_blocks.5.attn1",
+ "output_blocks.2.1.transformer_blocks.5.attn1",
+ "input_blocks.7.1.transformer_blocks.6.attn1",
+ "input_blocks.8.1.transformer_blocks.6.attn1",
+ "output_blocks.0.1.transformer_blocks.6.attn1",
+ "output_blocks.1.1.transformer_blocks.6.attn1",
+ "output_blocks.2.1.transformer_blocks.6.attn1",
+ "input_blocks.7.1.transformer_blocks.7.attn1",
+ "input_blocks.8.1.transformer_blocks.7.attn1",
+ "output_blocks.0.1.transformer_blocks.7.attn1",
+ "output_blocks.1.1.transformer_blocks.7.attn1",
+ "output_blocks.2.1.transformer_blocks.7.attn1",
+ "input_blocks.7.1.transformer_blocks.8.attn1",
+ "input_blocks.8.1.transformer_blocks.8.attn1",
+ "output_blocks.0.1.transformer_blocks.8.attn1",
+ "output_blocks.1.1.transformer_blocks.8.attn1",
+ "output_blocks.2.1.transformer_blocks.8.attn1",
+ "input_blocks.7.1.transformer_blocks.9.attn1",
+ "input_blocks.8.1.transformer_blocks.9.attn1",
+ "output_blocks.0.1.transformer_blocks.9.attn1",
+ "output_blocks.1.1.transformer_blocks.9.attn1",
+ "output_blocks.2.1.transformer_blocks.9.attn1",
+ ],
+ 2: [
+ # SD 1.5 U-Net (diffusers)
+ "mid_block.attentions.0.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "middle_block.1.transformer_blocks.0.attn1",
+ "middle_block.1.transformer_blocks.1.attn1",
+ "middle_block.1.transformer_blocks.2.attn1",
+ "middle_block.1.transformer_blocks.3.attn1",
+ "middle_block.1.transformer_blocks.4.attn1",
+ "middle_block.1.transformer_blocks.5.attn1",
+ "middle_block.1.transformer_blocks.6.attn1",
+ "middle_block.1.transformer_blocks.7.attn1",
+ "middle_block.1.transformer_blocks.8.attn1",
+ "middle_block.1.transformer_blocks.9.attn1",
+ ],
+ 3 : [] # TODO - separate layers for SD-XL
+}
+
+
+RNG_INSTANCE = random.Random()
+
+
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
+ """
+ Returns a random divisor of value that
+ x * min_value <= value
+ if max_options is 1, the behavior is deterministic
+ """
+ min_value = min(min_value, value)
+
+ # All big divisors of value (inclusive)
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
+
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
+
+ idx = RNG_INSTANCE.randint(0, len(ns) - 1)
+
+ return ns[idx]
+
+
+def set_hypertile_seed(seed: int) -> None:
+ RNG_INSTANCE.seed(seed)
+
+
+@functools.cache
+def largest_tile_size_available(width: int, height: int) -> int:
+ """
+ Calculates the largest tile size available for a given width and height
+ Tile size is always a power of 2
+ """
+ gcd = math.gcd(width, height)
+ largest_tile_size_available = 1
+ while gcd % (largest_tile_size_available * 2) == 0:
+ largest_tile_size_available *= 2
+ return largest_tile_size_available
+
+
+def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
+ """
+ Finds h and w such that h*w = hw and h/w = aspect_ratio
+ We check all possible divisors of hw and return the closest to the aspect ratio
+ """
+ divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
+ pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
+ ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
+ closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
+ closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
+ return closest_pair
+
+
+@cache
+def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
+ """
+ Finds h and w such that h*w = hw and h/w = aspect_ratio
+ """
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
+ # find h and w such that h*w = hw and h/w = aspect_ratio
+ if h * w != hw:
+ w_candidate = hw / h
+ # check if w is an integer
+ if not w_candidate.is_integer():
+ h_candidate = hw / w
+ # check if h is an integer
+ if not h_candidate.is_integer():
+ return iterative_closest_divisors(hw, aspect_ratio)
+ else:
+ h = int(h_candidate)
+ else:
+ w = int(w_candidate)
+ return h, w
+
+
+def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
+
+ @wraps(params.forward)
+ def wrapper(*args, **kwargs):
+ if not params.enabled:
+ return params.forward(*args, **kwargs)
+
+ latent_tile_size = max(128, params.tile_size) // 8
+ x = args[0]
+
+ # VAE
+ if x.ndim == 4:
+ b, c, h, w = x.shape
+
+ nh = random_divisor(h, latent_tile_size, params.swap_size)
+ nw = random_divisor(w, latent_tile_size, params.swap_size)
+
+ if nh * nw > 1:
+ x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
+
+ out = params.forward(x, *args[1:], **kwargs)
+
+ if nh * nw > 1:
+ out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
+
+ # U-Net
+ else:
+ hw: int = x.size(1)
+ h, w = find_hw_candidates(hw, params.aspect_ratio)
+ assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
+
+ factor = 2 ** params.depth if scale_depth else 1
+ nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
+ nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
+
+ if nh * nw > 1:
+ x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
+
+ out = params.forward(x, *args[1:], **kwargs)
+
+ if nh * nw > 1:
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
+
+ return out
+
+ return wrapper
+
+
+def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
+ hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
+ if hypertile_layers is None:
+ if not enable:
+ return
+
+ hypertile_layers = {}
+ layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
+
+ for depth in range(4):
+ for layer_name, module in model.named_modules():
+ if any(layer_name.endswith(try_name) for try_name in layers[depth]):
+ params = HypertileParams()
+ module.__webui_hypertile_params = params
+ params.forward = module.forward
+ params.depth = depth
+ params.layer_name = layer_name
+ module.forward = self_attn_forward(params)
+
+ hypertile_layers[layer_name] = 1
+
+ model.__webui_hypertile_layers = hypertile_layers
+
+ aspect_ratio = width / height
+ tile_size = min(largest_tile_size_available(width, height), tile_size_max)
+
+ for layer_name, module in model.named_modules():
+ if layer_name in hypertile_layers:
+ params = module.__webui_hypertile_params
+
+ params.tile_size = tile_size
+ params.swap_size = swap_size
+ params.aspect_ratio = aspect_ratio
+ params.enabled = enable and params.depth <= max_depth
diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py
new file mode 100644
index 00000000..3cc29cd1
--- /dev/null
+++ b/extensions-builtin/hypertile/scripts/hypertile_script.py
@@ -0,0 +1,73 @@
+import hypertile
+from modules import scripts, script_callbacks, shared
+
+
+class ScriptHypertile(scripts.Script):
+ name = "Hypertile"
+
+ def title(self):
+ return self.name
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def process(self, p, *args):
+ hypertile.set_hypertile_seed(p.all_seeds[0])
+
+ configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
+
+ def before_hr(self, p, *args):
+ configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
+
+
+def configure_hypertile(width, height, enable_unet=True):
+ hypertile.hypertile_hook_model(
+ shared.sd_model.first_stage_model,
+ width,
+ height,
+ swap_size=shared.opts.hypertile_swap_size_vae,
+ max_depth=shared.opts.hypertile_max_depth_vae,
+ tile_size_max=shared.opts.hypertile_max_tile_vae,
+ enable=shared.opts.hypertile_enable_vae,
+ )
+
+ hypertile.hypertile_hook_model(
+ shared.sd_model.model,
+ width,
+ height,
+ swap_size=shared.opts.hypertile_swap_size_unet,
+ max_depth=shared.opts.hypertile_max_depth_unet,
+ tile_size_max=shared.opts.hypertile_max_tile_unet,
+ enable=enable_unet,
+ is_sdxl=shared.sd_model.is_sdxl
+ )
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ options = {
+ "hypertile_explanation": shared.OptionHTML("""
+ <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
+ resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
+ benefit.
+ """),
+
+ "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
+ "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
+ "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
+ "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+
+ "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
+ "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
+ "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+ }
+
+ for name, opt in options.items():
+ opt.section = ('hypertile', "Hypertile")
+ shared.opts.add_option(name, opt)
+
+
+script_callbacks.on_ui_settings(on_ui_settings)