aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--launch.py8
-rw-r--r--modules/sd_hijack.py10
-rw-r--r--modules/sd_hijack_optimizations.py38
-rw-r--r--modules/shared.py3
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
6 files changed, 55 insertions, 6 deletions
diff --git a/launch.py b/launch.py
index 75edb66a..a592e1ba 100644
--- a/launch.py
+++ b/launch.py
@@ -4,6 +4,7 @@ import os
import sys
import importlib.util
import shlex
+import platform
dir_repos = "repositories"
dir_tmp = "tmp"
@@ -31,6 +32,7 @@ def extract_arg(args, name):
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
+args, xformers = extract_arg(args, '--xformers')
def repo_dir(name):
@@ -124,6 +126,12 @@ if not is_installed("gfpgan"):
if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
+if not is_installed("xformers") and xformers:
+ if platform.system() == "Windows":
+ run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
+ elif platform.system() == "Linux":
+ run_pip("install xformers", "xformers")
+
os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ba808a39..5d93f7f6 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -22,11 +22,13 @@ def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
-
- if cmd_opts.opt_split_attention_v1:
+ if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available:
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ elif cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ elif cmd_opts.opt_split_attention or torch.cuda.is_available():
+ ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 3351c740..e43e2c7a 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,14 @@
import math
import torch
from torch import einsum
-
+try:
+ import xformers.ops
+ import functorch
+ xformers._is_functorch_available = True
+ shared.xformers_available = True
+except:
+ print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')
+ continue
from ldm.util import default
from einops import rearrange
@@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
+def xformers_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ return self.to_out(out)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -177,3 +203,13 @@ def cross_attention_attnblock_forward(self, x):
h3 += x
return h3
+
+def xformers_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_).contiguous()
+ k1 = self.k(h_).contiguous()
+ v = self.v(h_).contiguous()
+ out = xformers.ops.memory_efficient_attention(q1, k1, v)
+ out = self.proj_out(out)
+ return x+out
diff --git a/modules/shared.py b/modules/shared.py
index 475d7e52..d68df751 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
@@ -73,7 +74,7 @@ device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
-
+xformers_available = False
config_filename = cmd_opts.ui_settings_file
hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
diff --git a/requirements.txt b/requirements.txt
index 631fe616..81641d68 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,3 +23,4 @@ resize-right
torchdiffeq
kornia
lark
+functorch
diff --git a/requirements_versions.txt b/requirements_versions.txt
index fdff2687..fec3e9d5 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -22,3 +22,4 @@ resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
lark==1.1.2
+functorch==0.2.1