aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py45
1 files changed, 15 insertions, 30 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 634fb4b2..3349b9c3 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,12 +9,12 @@ from ldm.util import default
from einops import rearrange
from modules import shared
+from modules.hypernetworks import hypernetwork
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
import xformers.ops
- import functorch
- xformers._is_functorch_available = True
shared.xformers_available = True
except Exception:
print("Cannot import xformers", file=sys.stderr)
@@ -28,16 +28,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
- del context, x
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
@@ -61,22 +55,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
return self.to_out(r2)
-# taken from https://github.com/Doggettx/stable-diffusion
+# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
k_in *= self.scale
@@ -132,14 +120,11 @@ def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)