author    brkirch <brkirch@users.noreply.github.com>  2023-05-08 18:16:01 -0400
committer brkirch <brkirch@users.noreply.github.com>  2023-08-13 10:06:25 -0400
commit    abfa4ad8bc995dcaf832c07a7cf75b6e295a8ca9 (patch)
tree      25ba8016073b738292c40a2ecfcf714533198f85 /modules/sd_hijack_optimizations.py
parent    3163d1269af7f9fd95382e58bb1581fd741b5119 (diff)
Use fixed size for sub-quadratic chunking on MPS
Even if this causes chunks to be much smaller, performance isn't significantly impacted. This will usually reduce memory usage, and it should also help with the poor performance that occurs when free memory is low.
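
For context, a minimal sketch of the fixed threshold this change introduces, written as a standalone helper (the function name fixed_mps_chunk_threshold_bytes is illustrative and not from the repository; bytes_per_token is assumed to be the element size in bytes of the attention tensors' dtype):

import platform

def fixed_mps_chunk_threshold_bytes(bytes_per_token: int) -> int:
    # 268435456 == 2**28, i.e. a 256 Mi element budget. On Intel Macs
    # (platform.processor() == 'i386') the multiplier is pinned to 2,
    # giving a constant 512 MiB; on Apple Silicon it scales with the
    # dtype's element size (512 MiB for fp16, 1 GiB for fp32).
    return 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)

Unlike the previous int(get_available_vram() * 0.9) threshold, this value does not shrink as free memory shrinks, so chunk sizes stay predictable under memory pressure.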
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--  modules/sd_hijack_optimizations.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 0e810eec..b3e71270 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import math
 import psutil
+import platform
 
 import torch
 from torch import einsum
@@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
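
As a rough worked example of what the new branch produces (all shape numbers below are illustrative, not taken from the commit):

# Hypothetical SD-sized self-attention call, fp16 on Apple Silicon:
batch_x_heads = 16           # batch * heads (illustrative)
q_tokens = k_tokens = 4096   # e.g. a 64x64 latent, flattened (illustrative)
bytes_per_token = 2          # fp16 element size

# Same formula as the context line in the hunk above:
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
# = 16 * 2 * 4096 * 4096 = 536870912 bytes (512 MiB)

chunk_threshold_bytes = 268435456 * bytes_per_token   # also 512 MiB, fixed

Here the attention matmul sits exactly at the fixed budget, and the chunking decision that follows in sub_quad_attention no longer depends on how much memory happens to be free at call time.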