Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--  modules/sd_hijack.py  108
1 file changed, 59 insertions, 49 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f4bb0266..08d31080 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -28,57 +28,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
+optimizers = []
+current_optimizer: sd_hijack_optimizations.SdOptimization = None
+
+
+def list_optimizers():
+    new_optimizers = script_callbacks.list_optimizers_callback()
+
+    new_optimizers = [x for x in new_optimizers if x.is_available()]
+
+    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
+
+    optimizers.clear()
+    optimizers.extend(new_optimizers)
+
def apply_optimizations():
+    global current_optimizer
+
    undo_optimizations()
    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
-    optimization_method = None
-    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
-
-    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
-        print("Applying xformers cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
-        optimization_method = 'xformers'
-    elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
-        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
-        optimization_method = 'sdp-no-mem'
-    elif cmd_opts.opt_sdp_attention and can_use_sdp:
-        print("Applying scaled dot product cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
-        optimization_method = 'sdp'
-    elif cmd_opts.opt_sub_quad_attention:
-        print("Applying sub-quadratic cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
-        optimization_method = 'sub-quadratic'
-    elif cmd_opts.opt_split_attention_v1:
-        print("Applying v1 cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-        optimization_method = 'V1'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
-        print("Applying cross attention optimization (InvokeAI).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
-        optimization_method = 'InvokeAI'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
-        print("Applying cross attention optimization (Doggettx).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
-        optimization_method = 'Doggettx'
-
-    return optimization_method
+    if current_optimizer is not None:
+        current_optimizer.undo()
+        current_optimizer = None
+
+    selection = shared.opts.cross_attention_optimization
+    if selection == "Automatic" and len(optimizers) > 0:
+        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
+    else:
+        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
+
+    if selection == "None":
+        matching_optimizer = None
+    elif matching_optimizer is None:
+        matching_optimizer = optimizers[0]
+
+    if matching_optimizer is not None:
+        print(f"Applying optimization: {matching_optimizer.name}")
+        matching_optimizer.apply()
+        current_optimizer = matching_optimizer
+        return current_optimizer.name
+    else:
+        return ''
def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
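
The hunk above replaces the hard-coded if/elif chain with a list of SdOptimization objects gathered from script callbacks: list_optimizers() filters and sorts the candidates, and apply_optimizations() picks one based on the new shared.opts.cross_attention_optimization setting. Going only by the attributes the new code relies on (name, cmd_opt, priority, title(), is_available(), apply(), undo()), a registered optimization could look roughly like the sketch below; the SdOptimization base-class internals and the script_callbacks.on_list_optimizers registration hook are assumptions, not shown in this diff.

# Illustrative sketch only: base-class details and the on_list_optimizers hook are assumed.
from modules import script_callbacks, sd_hijack_optimizations
from modules.hypernetworks import hypernetwork
import ldm.modules.attention


class SdOptimizationV1Sketch(sd_hijack_optimizations.SdOptimization):
    name = "V1"                           # printed by apply_optimizations() as "Applying optimization: V1"
    cmd_opt = "opt_split_attention_v1"    # checked via getattr(shared.cmd_opts, cmd_opt, False) in "Automatic" mode
    priority = 10                         # list_optimizers() sorts candidates by priority, highest first

    def is_available(self):
        return True                       # entries returning False are filtered out by list_optimizers()

    def apply(self):
        # reuse the existing v1 forward that the removed if/elif chain used to assign directly
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1

    def undo(self):
        # mirror undo_optimizations(): restore the default hypernetwork-aware forward
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward


def register(res):
    res.append(SdOptimizationV1Sketch())


script_callbacks.on_list_optimizers(register)  # hook name assumed; its results feed list_optimizers_callback()
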
@@ -92,12 +91,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True):
    #Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)
-
+
    #Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight
-
+
    #Return the loss, as mean if specified
    return loss.mean() if mean else loss
@@ -105,7 +104,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        #Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w
-
+
        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
@@ -118,9 +117,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
        try:
            #Delete temporary weights if appended
            del sd_model._custom_loss_weight
-        except AttributeError as e:
+        except AttributeError:
            pass
-
+
        #If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
@@ -133,7 +132,7 @@ def apply_weighted_forward(sd_model):
def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
-    except AttributeError as e:
+    except AttributeError:
        pass
@@ -169,7 +168,11 @@ class StableDiffusionModelHijack:
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()
-        self.optimization_method = apply_optimizations()
+        try:
+            self.optimization_method = apply_optimizations()
+        except Exception as e:
+            errors.display(e, "applying cross attention optimization")
+            undo_optimizations()
        self.clip = m.cond_stage_model
@@ -184,7 +187,7 @@ class StableDiffusionModelHijack:
    def undo_hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
-            m.cond_stage_model = m.cond_stage_model.wrapped
+            m.cond_stage_model = m.cond_stage_model.wrapped
        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped
@@ -216,10 +219,17 @@ class StableDiffusionModelHijack:
        self.comments = []
    def get_prompt_lengths(self, text):
+        if self.clip is None:
+            return "-", "-"
+
        _, token_count = self.clip.process_texts([text])
        return token_count, self.clip.get_target_prompt_token_count(token_count)
+    def redo_hijack(self, m):
+        self.undo_hijack(m)
+        self.hijack(m)
+
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
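
For context, the expected call order under the new scheme is roughly: list_optimizers() (after scripts have registered their entries) populates the module-level optimizers list, hijack() then applies the selected optimization inside the new try/except, and the added redo_hijack() re-applies everything when the cross_attention_optimization option changes. A rough usage sketch follows; the module-level model_hijack instance and the loaded sd_model are assumed rather than shown in this diff.

# Rough call-order sketch; model_hijack and sd_model are assumed, not part of this diff.
from modules import sd_hijack

sd_hijack.list_optimizers()                    # collect SdOptimization entries from script callbacks
sd_hijack.model_hijack.hijack(sd_model)        # applies the selected optimization via apply_optimizations()

# after shared.opts.cross_attention_optimization is changed in settings:
sd_hijack.model_hijack.redo_hijack(sd_model)   # undo_hijack() followed by hijack() re-applies the new choice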