path: root/modules/sd_hijack.py
author     Taithrah <Taithrah@users.noreply.github.com>   2023-01-08 15:58:53 -0500
committer  GitHub <noreply@github.com>                    2023-01-08 15:58:53 -0500
commit     e9d7eff70a3429ee299cbdcae1aeb61fc4d2bcbf (patch)
tree       2373cae4f4e4af72ed170647bf393015075791cc /modules/sd_hijack.py
parent     8a27730da5d5b25e28370e8ad94844856a839af9 (diff)
parent     8850fc23b6e8a8e210bdfe4aade81516fb5770f3 (diff)
Merge branch 'AUTOMATIC1111:master' into small-touch-up
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--   modules/sd_hijack.py   28
1 file changed, 13 insertions, 15 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 71cc145a..6b0d95af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
@@ -43,20 +41,19 @@ def apply_optimizations():
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
+    elif cmd_opts.opt_sub_quad_attention:
+        print("Applying sub-quadratic cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+        optimization_method = 'sub-quadratic'
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
-        if not invokeAI_mps_available and shared.device.type == 'mps':
-            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
-            print("Applying v1 cross attention optimization.")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-            optimization_method = 'V1'
-        else:
-            print("Applying cross attention optimization (InvokeAI).")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
-            optimization_method = 'InvokeAI'
+    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
+        print("Applying cross attention optimization (InvokeAI).")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+        optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
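
Note: the hunk above changes the branch order in apply_optimizations(): xformers is tried first, then the new sub-quadratic path, then the v1 split, then InvokeAI, then Doggettx. Below is a minimal standalone sketch of that precedence; the opts object, the cuda_available flag, and the xformers condition are illustrative stand-ins, and only the cmd_opts flags visible in the hunk are taken from the diff.

# Illustrative paraphrase of the selection order introduced above; not the module's actual code.
def choose_optimization(opts, cuda_available):
    if opts.xformers:  # assumed flag; the real xformers condition sits outside this hunk
        return 'xformers'
    if opts.opt_sub_quad_attention:
        return 'sub-quadratic'
    if opts.opt_split_attention_v1:
        return 'V1'
    if not opts.disable_opt_split_attention and (
            opts.opt_split_attention_invokeai
            or (not opts.opt_split_attention and not cuda_available)):
        return 'InvokeAI'
    if not opts.disable_opt_split_attention and (opts.opt_split_attention or cuda_available):
        return 'Doggettx'
    return None

In particular, the rewritten InvokeAI condition now falls back to InvokeAI on non-CUDA devices only when --opt-split-attention was not explicitly requested, and the psutil/MPS special case disappears along with the invokeAI_mps_available import removed in the first hunk.
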
@@ -86,10 +83,12 @@ class StableDiffusionModelHijack:
    clip = None
    optimization_method = None
-    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
+    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-    def hijack(self, m):
+    def __init__(self):
+        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+    def hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
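
Note: the hunk above moves the embeddings directory out of the EmbeddingDatabase constructor; the class-level database is now created empty and the directory is registered in the new __init__ via add_embedding_dir. A rough sketch of the resulting call pattern follows, using only the two calls visible in the hunk; anything beyond cmd_opts.embeddings_dir is hypothetical.

# Sketch of the new wiring inside the stable-diffusion-webui tree; not the class itself.
import modules.textual_inversion.textual_inversion as textual_inversion
from modules.shared import cmd_opts

embedding_db = textual_inversion.EmbeddingDatabase()       # was: EmbeddingDatabase(cmd_opts.embeddings_dir)
embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)    # directory registered after construction
# further directories could presumably be registered the same way:
# embedding_db.add_embedding_dir("/some/other/embeddings")  # hypothetical extra dir
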
@@ -120,7 +119,6 @@ class StableDiffusionModelHijack:
        self.layers = flatten(m)
    def undo_hijack(self, m):
-
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped