From d2e0c1ca132f4f0d98b77397a9f353d4ad8e7c4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 10:51:45 +0300 Subject: rework hypertile into a built-in extension --- README.md | 1 + extensions-builtin/hypertile/hypertile.py | 221 +++++++++------------ .../hypertile/scripts/hypertile_script.py | 73 +++++++ modules/processing.py | 37 ++-- modules/shared_options.py | 8 - 5 files changed, 186 insertions(+), 154 deletions(-) create mode 100644 extensions-builtin/hypertile/scripts/hypertile_script.py diff --git a/README.md b/README.md index 25ba070e..3b3f93ad 100644 --- a/README.md +++ b/README.md @@ -174,5 +174,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - LyCORIS - KohakuBlueleaf - Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling +- Hypertile - tfernd - https://github.com/tfernd/HyperTile - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py index be898fce..a40c1311 100644 --- a/extensions-builtin/hypertile/hypertile.py +++ b/extensions-builtin/hypertile/hypertile.py @@ -1,10 +1,13 @@ """ Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE -Warn : The patch works well only if the input image has a width and height that are multiples of 128 -Author : @tfernd Github : https://github.com/tfernd/HyperTile +Warn: The patch works well only if the input image has a width and height that are multiples of 128 +Original author: @tfernd Github: https://github.com/tfernd/HyperTile """ from __future__ import annotations + +import functools +from dataclasses import dataclass from typing import Callable from typing_extensions import Literal @@ -18,6 +21,19 @@ import random from einops import rearrange + +@dataclass +class HypertileParams: + depth = 0 + layer_name = "" + tile_size: int = 0 + swap_size: int = 0 + aspect_ratio: float = 1.0 + forward = None + enabled = False + + + # TODO add SD-XL layers DEPTH_LAYERS = { 0: [ @@ -176,6 +192,7 @@ DEPTH_LAYERS_XL = { RNG_INSTANCE = random.Random() + def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: """ Returns a random divisor of value that @@ -193,10 +210,13 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: return ns[idx] + def set_hypertile_seed(seed: int) -> None: RNG_INSTANCE.seed(seed) -def largest_tile_size_available(width:int, height:int) -> int: + +@functools.cache +def largest_tile_size_available(width: int, height: int) -> int: """ Calculates the largest tile size available for a given width and height Tile size is always a power of 2 @@ -207,6 +227,7 @@ def largest_tile_size_available(width:int, height:int) -> int: largest_tile_size_available *= 2 return largest_tile_size_available + def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: """ Finds h and w such that h*w = hw and h/w = aspect_ratio @@ -219,6 +240,7 @@ def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio return closest_pair + @cache def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: """ @@ -240,132 +262,87 @@ def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: w = int(w_candidate) return h, w -@contextmanager -def split_attention( - layer: nn.Module, - /, - aspect_ratio: float, # width/height - tile_size: int = 128, # 128 for VAE - swap_size: int = 1, # 1 for VAE - *, - disable: bool = False, - max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1 - scale_depth: bool = True, # scale the tile-size depending on the depth - is_sdxl: bool = False, # is the model SD-XL -): - # Hijacks AttnBlock from ldm and Attention from diffusers - - if disable: - logging.info(f"Attention for {layer.__class__.__qualname__} not splitted") - yield - return - - latent_tile_size = max(128, tile_size) // 8 - - def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable: - @wraps(forward) - def wrapper(*args, **kwargs): - x = args[0] - - # VAE - if x.ndim == 4: - b, c, h, w = x.shape - - nh = random_divisor(h, latent_tile_size, swap_size) - nw = random_divisor(w, latent_tile_size, swap_size) - - if nh * nw > 1: - x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles - - out = forward(x, *args[1:], **kwargs) - - if nh * nw > 1: - out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) - - # U-Net - else: - hw: int = x.size(1) - h, w = find_hw_candidates(hw, aspect_ratio) - assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" - factor = 2**depth if scale_depth else 1 - nh = random_divisor(h, latent_tile_size * factor, swap_size) - nw = random_divisor(w, latent_tile_size * factor, swap_size) +def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable: + + @wraps(params.forward) + def wrapper(*args, **kwargs): + if not params.enabled: + return params.forward(*args, **kwargs) - module._split_sizes_hypertile.append((nh, nw)) # type: ignore + latent_tile_size = max(128, params.tile_size) // 8 + x = args[0] - if nh * nw > 1: - x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + # VAE + if x.ndim == 4: + b, c, h, w = x.shape - out = forward(x, *args[1:], **kwargs) + nh = random_divisor(h, latent_tile_size, params.swap_size) + nw = random_divisor(w, latent_tile_size, params.swap_size) - if nh * nw > 1: - out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) - out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + if nh * nw > 1: + x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles - return out + out = params.forward(x, *args[1:], **kwargs) - return wrapper + if nh * nw > 1: + out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) - # Handle hijacking the forward method and recovering afterwards - try: - if is_sdxl: - layers = DEPTH_LAYERS_XL + # U-Net else: - layers = DEPTH_LAYERS - for depth in range(max_depth + 1): - for layer_name, module in layer.named_modules(): + hw: int = x.size(1) + h, w = find_hw_candidates(hw, params.aspect_ratio) + assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" + + factor = 2 ** params.depth if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, params.swap_size) + nw = random_divisor(w, latent_tile_size * factor, params.swap_size) + + if nh * nw > 1: + x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + + out = params.forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + + return out + + return wrapper + + +def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False): + hypertile_layers = getattr(model, "__webui_hypertile_layers", None) + if hypertile_layers is None: + if not enable: + return + + hypertile_layers = {} + layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS + + for depth in range(4): + for layer_name, module in model.named_modules(): if any(layer_name.endswith(try_name) for try_name in layers[depth]): - # print input shape for debugging - logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}") - # hijack - module._original_forward_hypertile = module.forward - module.forward = self_attn_forward(module.forward, depth, layer_name, module) - module._split_sizes_hypertile = [] - yield - finally: - for layer_name, module in layer.named_modules(): - # remove hijack - if hasattr(module, "_original_forward_hypertile"): - if module._split_sizes_hypertile: - logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})") - # recover - module.forward = module._original_forward_hypertile - del module._original_forward_hypertile - del module._split_sizes_hypertile - -def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts): - """ - Returns context manager for VAE - """ - enabled = opts.hypertile_split_vae_attn - swap_size = opts.hypertile_swap_size_vae - max_depth = opts.hypertile_max_depth_vae - tile_size_max = opts.hypertile_max_tile_vae - return split_attention( - model, - aspect_ratio=aspect_ratio, - tile_size=min(tile_size, tile_size_max), - swap_size=swap_size, - disable=not enabled, - max_depth=max_depth, - is_sdxl=False, - ) - -def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool): - """ - Returns context manager for U-Net - """ - enabled = opts.hypertile_split_unet_attn - swap_size = opts.hypertile_swap_size_unet - max_depth = opts.hypertile_max_depth_unet - tile_size_max = opts.hypertile_max_tile_unet - return split_attention( - model, - aspect_ratio=aspect_ratio, - tile_size=min(tile_size, tile_size_max), - swap_size=swap_size, - disable=not enabled, - max_depth=max_depth, - is_sdxl=is_sdxl, - ) + params = HypertileParams() + module.__webui_hypertile_params = params + params.forward = module.forward + params.depth = depth + params.layer_name = layer_name + module.forward = self_attn_forward(params) + + hypertile_layers[layer_name] = 1 + + model.__webui_hypertile_layers = hypertile_layers + + aspect_ratio = width / height + tile_size = min(largest_tile_size_available(width, height), tile_size_max) + + for layer_name, module in model.named_modules(): + if layer_name in hypertile_layers: + params = module.__webui_hypertile_params + + params.tile_size = tile_size + params.swap_size = swap_size + params.aspect_ratio = aspect_ratio + params.enabled = enable and params.depth <= max_depth diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py new file mode 100644 index 00000000..3cc29cd1 --- /dev/null +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -0,0 +1,73 @@ +import hypertile +from modules import scripts, script_callbacks, shared + + +class ScriptHypertile(scripts.Script): + name = "Hypertile" + + def title(self): + return self.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def process(self, p, *args): + hypertile.set_hypertile_seed(p.all_seeds[0]) + + configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + + def before_hr(self, p, *args): + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet) + + +def configure_hypertile(width, height, enable_unet=True): + hypertile.hypertile_hook_model( + shared.sd_model.first_stage_model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_vae, + max_depth=shared.opts.hypertile_max_depth_vae, + tile_size_max=shared.opts.hypertile_max_tile_vae, + enable=shared.opts.hypertile_enable_vae, + ) + + hypertile.hypertile_hook_model( + shared.sd_model.model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_unet, + max_depth=shared.opts.hypertile_max_depth_unet, + tile_size_max=shared.opts.hypertile_max_tile_unet, + enable=enable_unet, + is_sdxl=shared.sd_model.is_sdxl + ) + + +def on_ui_settings(): + import gradio as gr + + options = { + "hypertile_explanation": shared.OptionHTML(""" + Hypertile optimizes the self-attention layer within U-Net and VAE models, + resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the + benefit. + """), + + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}), + + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}), + } + + for name, opt in options.items(): + opt.section = ('hypertile', "Hypertile") + shared.opts.add_option(name, opt) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/modules/processing.py b/modules/processing.py index 36c2be5e..ac58ef86 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -24,7 +24,6 @@ from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths import modules.face_restoration -from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae import modules.images as images import modules.styles import modules.sd_models as sd_models @@ -861,8 +860,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.comment(comment) p.extra_generation_params.update(model_hijack.extra_generation_params) - set_hypertile_seed(p.seed) - # add batch size + hypertile status to information to reproduce the run + if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" @@ -874,8 +872,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts): - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -1141,25 +1138,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - aspect_ratio = self.width / self.height + x = self.rng.next() - tile_size = largest_tile_size_available(self.width, self.height) - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x + if not self.enable_hr: return samples devices.torch_gc() if self.latent_scale_mode is None: - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) + return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): @@ -1244,18 +1239,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scripts is not None: self.scripts.before_hr(self) - tile_size = largest_tile_size_available(target_width, target_height) - aspect_ratio = self.width / self.height - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) + + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.sampler = None devices.torch_gc() - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False return decoded_samples @@ -1532,11 +1524,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.initial_noise_multiplier != 1.0: self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier x *= self.initial_noise_multiplier - aspect_ratio = self.width / self.height - tile_size = largest_tile_size_available(self.width, self.height) - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) + + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask diff --git a/modules/shared_options.py b/modules/shared_options.py index 28a48906..d40db530 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,14 +200,6 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), - "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), - "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), - "hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { -- cgit v1.2.1