aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-05-21 05:00:27 -0400
committerbrkirch <brkirch@users.noreply.github.com>2023-08-13 10:06:25 -0400
commit87dd685224b5f7dbbd832fc73cc08e7e470c9f28 (patch)
treefbfd8c08f158918cb237a6acfa9a2192c1f7dc95 /modules
parentabfa4ad8bc995dcaf832c07a7cf75b6e295a8ca9 (diff)
Make sub-quadratic the default for MPS
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_hijack_optimizations.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b3e71270..7f9e328d 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -95,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization):
name = "sub-quadratic"
cmd_opt = "opt_sub_quad_attention"
- priority = 10
+
+ @property
+ def priority(self):
+ return 1000 if shared.device.type == 'mps' else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -121,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
@property
def priority(self):
- return 1000 if not torch.cuda.is_available() else 10
+ return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI