aboutsummaryrefslogtreecommitdiff
path: root/modules/hypertile.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypertile.py')
-rw-r--r--modules/hypertile.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/modules/hypertile.py b/modules/hypertile.py
index 32d8604c..fee24a8c 100644
--- a/modules/hypertile.py
+++ b/modules/hypertile.py
@@ -332,3 +332,39 @@ def split_attention(
module.forward = module._original_forward_hypertile
del module._original_forward_hypertile
del module._split_sizes_hypertile
+
+def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
+ """
+ Returns context manager for VAE
+ """
+ enabled = not opts.hypertile_split_vae_attn
+ swap_size = opts.hypertile_swap_size_vae
+ max_depth = opts.hypertile_max_depth_vae
+ tile_size_max = opts.hypertile_max_tile_vae
+ return split_attention(
+ model,
+ aspect_ratio=aspect_ratio,
+ tile_size=min(tile_size, tile_size_max),
+ swap_size=swap_size,
+ disable=not enabled,
+ max_depth=max_depth,
+ is_sdxl=False,
+ )
+
+def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool):
+ """
+ Returns context manager for U-Net
+ """
+ enabled = not opts.hypertile_split_unet_attn
+ swap_size = opts.hypertile_swap_size_unet
+ max_depth = opts.hypertile_max_depth_unet
+ tile_size_max = opts.hypertile_max_tile_unet
+ return split_attention(
+ model,
+ aspect_ratio=aspect_ratio,
+ tile_size=min(tile_size, tile_size_max),
+ swap_size=swap_size,
+ disable=not enabled,
+ max_depth=max_depth,
+ is_sdxl=is_sdxl,
+ ) \ No newline at end of file