aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_hijack_optimizations.py66
1 files changed, 28 insertions, 38 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index db1e4367..0eb4c525 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -19,10 +19,10 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At
class SdOptimization:
- def __init__(self, name, label=None, cmd_opt=None):
- self.name = name
- self.label = label
- self.cmd_opt = cmd_opt
+ name: str = None
+ label: str | None = None
+ cmd_opt: str | None = None
+ priority: int = 0
def title(self):
if self.label is None:
@@ -33,9 +33,6 @@ class SdOptimization:
def is_available(self):
return True
- def priority(self):
- return 0
-
def apply(self):
pass
@@ -45,41 +42,37 @@ class SdOptimization:
class SdOptimizationXformers(SdOptimization):
- def __init__(self):
- super().__init__("xformers", cmd_opt="xformers")
+ name = "xformers"
+ cmd_opt = "xformers"
+ priority = 100
def is_available(self):
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
- def priority(self):
- return 100
-
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
- def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
- super().__init__(name, label, cmd_opt)
+ name = "sdp-no-mem"
+ label = "scaled dot product without memory efficient attention"
+ cmd_opt = "opt_sdp_no_mem_attention"
+ priority = 90
def is_available(self):
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
- def priority(self):
- return 90
-
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
- def __init__(self):
- super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
-
- def priority(self):
- return 80
+ name = "sdp"
+ label = "scaled dot product"
+ cmd_opt = "opt_sdp_attention"
+ priority = 80
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
@@ -87,11 +80,9 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization):
- def __init__(self):
- super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
-
- def priority(self):
- return 10
+ name = "sub-quadratic"
+ cmd_opt = "opt_sub_quad_attention"
+ priority = 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -99,20 +90,21 @@ class SdOptimizationSubQuad(SdOptimization):
class SdOptimizationV1(SdOptimization):
- def __init__(self):
- super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
+ name = "V1"
+ label = "original v1"
+ cmd_opt = "opt_split_attention_v1"
+ priority = 10
- def priority(self):
- return 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
- def __init__(self):
- super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
+ name = "InvokeAI"
+ cmd_opt = "opt_split_attention_invokeai"
+ @property
def priority(self):
return 1000 if not torch.cuda.is_available() else 10
@@ -121,11 +113,9 @@ class SdOptimizationInvokeAI(SdOptimization):
class SdOptimizationDoggettx(SdOptimization):
- def __init__(self):
- super().__init__("Doggettx", cmd_opt="opt_split_attention")
-
- def priority(self):
- return 20
+ name = "Doggettx"
+ cmd_opt = "opt_split_attention"
+ priority = 20
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward