aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--javascript/edit-attention.js2
-rw-r--r--launch.py2
-rw-r--r--modules/cmd_args.py14
-rw-r--r--modules/extensions.py7
-rw-r--r--modules/extra_networks.py5
-rw-r--r--modules/script_callbacks.py21
-rw-r--r--modules/scripts.py21
-rw-r--r--modules/sd_hijack.py89
-rw-r--r--modules/sd_hijack_optimizations.py125
-rw-r--r--modules/sd_models.py1
-rw-r--r--modules/sd_samplers_kdiffusion.py14
-rw-r--r--modules/shared.py8
-rw-r--r--modules/shared_items.py8
-rw-r--r--modules/ui.py8
-rw-r--r--modules/ui_extensions.py4
-rw-r--r--modules/ui_extra_networks.py11
-rw-r--r--webui.py237
17 files changed, 389 insertions, 188 deletions
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index fdf00b4d..ffa73147 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,6 +1,6 @@
function keyupEditAttention(event) {
let target = event.originalTarget || event.composedPath()[0];
- if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
+ if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
if (!(event.metaKey || event.ctrlKey)) return;
let isPlus = event.key == "ArrowUp";
diff --git a/launch.py b/launch.py
index b9b5b709..1d504c38 100644
--- a/launch.py
+++ b/launch.py
@@ -236,7 +236,7 @@ def prepare_environment():
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
- k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 8625690b..3eeb84d5 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
-parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
-parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
diff --git a/modules/extensions.py b/modules/extensions.py
index 359a7aa5..624832a0 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -65,11 +65,12 @@ class Extension:
try:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
- self.commit_date = repo.head.commit.committed_date
+ commit = repo.head.commit
+ self.commit_date = commit.committed_date
if repo.active_branch:
self.branch = repo.active_branch.name
- self.commit_hash = repo.head.commit.hexsha
- self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
+ self.commit_hash = commit.hexsha
+ self.version = self.commit_hash[:8]
except Exception as ex:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 54982009..34a3ba63 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -14,6 +14,11 @@ def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_default_extra_networks():
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
+ register_extra_network(ExtraNetworkHypernet())
+
+
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 3c21a362..40f388a5 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -110,6 +110,7 @@ callback_map = dict(
callbacks_script_unloaded=[],
callbacks_before_ui=[],
callbacks_on_reload=[],
+ callbacks_list_optimizers=[],
)
@@ -258,6 +259,18 @@ def before_ui_callback():
report_exception(c, 'before_ui')
+def list_optimizers_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_optimizers']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_optimizers')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -409,3 +422,11 @@ def on_before_ui(callback):
"""register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback)
+
+
+def on_list_optimizers(callback):
+ """register a function to be called when UI is making a list of cross attention optimization options.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
+ to it."""
+
+ add_callback(callback_map['callbacks_list_optimizers'], callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index e33d8c81..c902804b 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -271,6 +271,12 @@ def load_scripts():
sys.path = syspath
current_basedir = paths.script_path
+ global scripts_txt2img, scripts_img2img, scripts_postproc
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -527,9 +533,9 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
-scripts_txt2img = ScriptRunner()
-scripts_img2img = ScriptRunner()
-scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+scripts_txt2img: ScriptRunner = None
+scripts_img2img: ScriptRunner = None
+scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None
@@ -539,14 +545,7 @@ def reload_script_body_only():
scripts_img2img.reload_sources(cache)
-def reload_scripts():
- global scripts_txt2img, scripts_img2img, scripts_postproc
-
- load_scripts()
-
- scripts_txt2img = ScriptRunner()
- scripts_img2img = ScriptRunner()
- scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+reload_scripts = load_scripts # compatibility alias
def add_classes_to_gradio_component(comp):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 14e7f799..08d31080 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -28,57 +28,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
+optimizers = []
+current_optimizer: sd_hijack_optimizations.SdOptimization = None
+
+
+def list_optimizers():
+ new_optimizers = script_callbacks.list_optimizers_callback()
+
+ new_optimizers = [x for x in new_optimizers if x.is_available()]
+
+ new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
+
+ optimizers.clear()
+ optimizers.extend(new_optimizers)
+
def apply_optimizations():
+ global current_optimizer
+
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
- optimization_method = None
+ if current_optimizer is not None:
+ current_optimizer.undo()
+ current_optimizer = None
+
+ selection = shared.opts.cross_attention_optimization
+ if selection == "Automatic" and len(optimizers) > 0:
+ matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
+ else:
+ matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
-
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
- print("Applying xformers cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
- optimization_method = 'xformers'
- elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
- print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
- optimization_method = 'sdp-no-mem'
- elif cmd_opts.opt_sdp_attention and can_use_sdp:
- print("Applying scaled dot product cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
- optimization_method = 'sdp'
- elif cmd_opts.opt_sub_quad_attention:
- print("Applying sub-quadratic cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
- optimization_method = 'sub-quadratic'
- elif cmd_opts.opt_split_attention_v1:
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'V1'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'InvokeAI'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization (Doggettx).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- optimization_method = 'Doggettx'
-
- return optimization_method
+ if selection == "None":
+ matching_optimizer = None
+ elif matching_optimizer is None:
+ matching_optimizer = optimizers[0]
+
+ if matching_optimizer is not None:
+ print(f"Applying optimization: {matching_optimizer.name}")
+ matching_optimizer.apply()
+ current_optimizer = matching_optimizer
+ return current_optimizer.name
+ else:
+ return ''
def undo_optimizations():
- ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
@@ -169,7 +168,11 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
- self.optimization_method = apply_optimizations()
+ try:
+ self.optimization_method = apply_optimizations()
+ except Exception as e:
+ errors.display(e, "applying cross attention optimization")
+ undo_optimizations()
self.clip = m.cond_stage_model
@@ -223,6 +226,10 @@ class StableDiffusionModelHijack:
return token_count, self.clip.get_target_prompt_token_count(token_count)
+ def redo_hijack(self, m):
+ self.undo_hijack(m)
+ self.hijack(m)
+
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f00fe55c..0eb4c525 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,10 +9,129 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork
-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+
+class SdOptimization:
+ name: str = None
+ label: str | None = None
+ cmd_opt: str | None = None
+ priority: int = 0
+
+ def title(self):
+ if self.label is None:
+ return self.name
+
+ return f"{self.name} - {self.label}"
+
+ def is_available(self):
+ return True
+
+ def apply(self):
+ pass
+
+ def undo(self):
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+ name = "xformers"
+ cmd_opt = "xformers"
+ priority = 100
+
+ def is_available(self):
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+ name = "sdp-no-mem"
+ label = "scaled dot product without memory efficient attention"
+ cmd_opt = "opt_sdp_no_mem_attention"
+ priority = 90
+
+ def is_available(self):
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+ name = "sdp"
+ label = "scaled dot product"
+ cmd_opt = "opt_sdp_attention"
+ priority = 80
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+ name = "sub-quadratic"
+ cmd_opt = "opt_sub_quad_attention"
+ priority = 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+ name = "V1"
+ label = "original v1"
+ cmd_opt = "opt_split_attention_v1"
+ priority = 10
+
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+ name = "InvokeAI"
+ cmd_opt = "opt_split_attention_invokeai"
+
+ @property
+ def priority(self):
+ return 1000 if not torch.cuda.is_available() else 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+ name = "Doggettx"
+ cmd_opt = "opt_split_attention"
+ priority = 20
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+ res.extend([
+ SdOptimizationXformers(),
+ SdOptimizationSdpNoMem(),
+ SdOptimizationSdp(),
+ SdOptimizationSubQuad(),
+ SdOptimizationV1(),
+ SdOptimizationInvokeAI(),
+ SdOptimizationDoggettx(),
+ ])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -299,7 +418,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
+ return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8e42bfea..b1afbaa7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -98,7 +98,6 @@ def setup_model():
if not os.path.exists(model_path):
os.makedirs(model_path)
- list_models()
enable_midas_autodownload()
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 552c6c64..59982fc9 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -19,7 +19,8 @@ samplers_k_diffusion = [
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True, 'discard_next_to_last_sigma': True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@@ -27,7 +28,8 @@ samplers_k_diffusion = [
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True, 'discard_next_to_last_sigma': True}),
]
samplers_data_k_diffusion = [
@@ -228,7 +230,7 @@ class KDiffusionSampler:
self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None
+ self.config = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
@@ -337,13 +339,13 @@ class KDiffusionSampler:
if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigma_sched
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args={
+ extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
@@ -373,7 +375,7 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
diff --git a/modules/shared.py b/modules/shared.py
index fa080458..3099d1d2 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -15,6 +15,7 @@ import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import Optional
demo = None
@@ -113,7 +114,7 @@ class State:
time_start = None
server_start = None
_server_command_signal = threading.Event()
- _server_command: str | None = None
+ _server_command: Optional[str] = None
@property
def need_restart(self) -> bool:
@@ -131,14 +132,14 @@ class State:
return self._server_command
@server_command.setter
- def server_command(self, value: str | None) -> None:
+ def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
- def wait_for_server_command(self, timeout: float | None = None) -> str | None:
+ def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
@@ -417,6 +418,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
+ "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
diff --git a/modules/shared_items.py b/modules/shared_items.py
index e792a134..2a8713c8 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -21,3 +21,11 @@ def refresh_vae_list():
import modules.sd_vae
modules.sd_vae.refresh_vae_list()
+
+
+def cross_attention_optimizations():
+ import modules.sd_hijack
+
+ return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+
+
diff --git a/modules/ui.py b/modules/ui.py
index 82820ab5..c0626587 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -272,12 +272,12 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
button_interrogate = None
button_deepbooru = None
@@ -505,10 +505,10 @@ def create_ui():
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
- hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.")
+ hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
- hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.")
+ hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "batch":
if not opts.dimensions_and_batch_together:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index ef18f438..515ec262 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -345,12 +345,12 @@ def install_extension_from_url(dirname, url, branch_name=None):
shutil.rmtree(tmpdir, True)
if not branch_name:
# if no branch is specified, use the default branch
- with git.Repo.clone_from(url, tmpdir) as repo:
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
else:
- with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 80cfa841..19fbaae5 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -232,10 +232,19 @@ class ExtraNetworksPage:
return None
-def intialize():
+def initialize():
extra_pages.clear()
+def register_default_pages():
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
+ from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
+ from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
+ register_page(ExtraNetworksPageTextualInversion())
+ register_page(ExtraNetworksPageHypernetworks())
+ register_page(ExtraNetworksPageCheckpoints())
+
+
class ExtraNetworksUi:
def __init__(self):
self.pages = None
diff --git a/webui.py b/webui.py
index b4a21e73..a76e377c 100644
--- a/webui.py
+++ b/webui.py
@@ -7,6 +7,7 @@ import re
import warnings
import json
from threading import Thread
+from typing import Iterable
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
@@ -14,6 +15,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
+
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import paths, timer, import_hook, errors # noqa: F401
@@ -34,8 +36,7 @@ startup_timer.record("import gradio")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
-from modules import extra_networks, ui_extra_networks_checkpoints
-from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
+from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
@@ -52,6 +53,7 @@ import modules.img2img
import modules.lowvram
import modules.scripts
import modules.sd_hijack
+import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.txt2img
@@ -162,13 +164,97 @@ def restore_config_state_file():
print(f"!!! Config state backup not found: {config_state_file}")
+def validate_tls_options():
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
+ return
+
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
+ startup_timer.record("TLS")
+
+
+def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
+ """
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
+ an iterable of (username, password) tuples.
+ """
+ def process_credential_line(s) -> tuple[str, ...] | None:
+ s = s.strip()
+ if not s:
+ return None
+ return tuple(s.split(':', 1))
+
+ if cmd_opts.gradio_auth:
+ for cred in cmd_opts.gradio_auth.split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+ if cmd_opts.gradio_auth_path:
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ for cred in line.strip().split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+
+def configure_sigint_handler():
+ # make the program just exit at ctrl+c without waiting for anything
+ def sigint_handler(sig, frame):
+ print(f'Interrupted with signal {sig} in {frame}')
+ os._exit(0)
+
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
+
+
+def configure_opts_onchange():
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ startup_timer.record("opts onchange")
+
+
def initialize():
fix_asyncio_event_loop_policy()
-
+ validate_tls_options()
+ configure_sigint_handler()
check_versions()
+ modelloader.cleanup_models()
+ configure_opts_onchange()
+
+ modules.sd_models.setup_model()
+ startup_timer.record("setup SD model")
+
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
+ startup_timer.record("setup codeformer")
+
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+ startup_timer.record("setup gfpgan")
+
+ initialize_rest(reload_script_modules=False)
+
+def initialize_rest(*, reload_script_modules=False):
+ """
+ Called both from initialize() and when reloading the webui.
+ """
+ sd_samplers.set_samplers()
extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
restore_config_state_file()
@@ -178,85 +264,64 @@ def initialize():
modules.scripts.load_scripts()
return
- modelloader.cleanup_models()
- modules.sd_models.setup_model()
+ modules.sd_models.list_models()
startup_timer.record("list SD models")
- codeformer.setup_model(cmd_opts.codeformer_models_path)
- startup_timer.record("setup codeformer")
-
- gfpgan.setup_model(cmd_opts.gfpgan_models_path)
- startup_timer.record("setup gfpgan")
+ localization.list_localizations(cmd_opts.localizations_dir)
modules.scripts.load_scripts()
startup_timer.record("load scripts")
+ if reload_script_modules:
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+ startup_timer.record("reload script modules")
+
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
-
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
+ modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
+ modules.sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
+
# load model in parallel to other startup stuff
+ # (when reloading, this does nothing)
Thread(target=lambda: shared.sd_model).start()
- shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
- shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
- shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
- startup_timer.record("opts onchange")
-
shared.reload_hypernetworks()
- startup_timer.record("reload hypernets")
+ startup_timer.record("reload hypernetworks")
- ui_extra_networks.intialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+ ui_extra_networks.initialize()
+ ui_extra_networks.register_default_pages()
extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
- startup_timer.record("extra networks")
-
- if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
-
- try:
- if not os.path.exists(cmd_opts.tls_keyfile):
- print("Invalid path to TLS keyfile given")
- if not os.path.exists(cmd_opts.tls_certfile):
- print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
- except TypeError:
- cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
- print("TLS setup invalid, running webui without TLS")
- else:
- print("Running with TLS")
- startup_timer.record("TLS")
-
- # make the program just exit at ctrl+c without waiting for anything
- def sigint_handler(sig, frame):
- print(f'Interrupted with signal {sig} in {frame}')
- os._exit(0)
-
- if not os.environ.get("COVERAGE_RUN"):
- # Don't install the immediate-quit handler when running under coverage,
- # as then the coverage report won't be generated.
- signal.signal(signal.SIGINT, sigint_handler)
+ extra_networks.register_default_extra_networks()
+ startup_timer.record("initialize extra networks")
def setup_middleware(app):
- app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
- if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- app.build_middleware_stack() # rebuild middleware stack on-the-fly
+ configure_cors_middleware(app)
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
+
+
+def configure_cors_middleware(app):
+ cors_options = {
+ "allow_methods": ["*"],
+ "allow_headers": ["*"],
+ "allow_credentials": True,
+ }
+ if cmd_opts.cors_allow_origins:
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
+ if cmd_opts.cors_allow_origins_regex:
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
+ app.add_middleware(CORSMiddleware, **cors_options)
def create_api(app):
@@ -301,16 +366,11 @@ def webui():
if not cmd_opts.no_gradio_queue:
shared.demo.queue(64)
- gradio_auth_creds = []
- if cmd_opts.gradio_auth:
- gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
- if cmd_opts.gradio_auth_path:
- with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
- for line in file.readlines():
- gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ gradio_auth_creds = list(get_gradio_auth_creds()) or None
# this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'):
+ # TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
def fastapi_setup(self):
self.docs_url = "/docs"
self.redoc_url = "/redoc"
@@ -327,7 +387,7 @@ def webui():
ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
- auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
+ auth=gradio_auth_creds,
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
@@ -386,47 +446,16 @@ def webui():
print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
- modules.script_callbacks.app_reload_callback()
-
startup_timer.reset()
-
- sd_samplers.set_samplers()
-
+ modules.script_callbacks.app_reload_callback()
+ startup_timer.record("app reload callback")
modules.script_callbacks.script_unloaded_callback()
- extensions.list_extensions()
- startup_timer.record("list extensions")
-
- restore_config_state_file()
-
- localization.list_localizations(cmd_opts.localizations_dir)
-
- modules.scripts.reload_scripts()
- startup_timer.record("load scripts")
-
- modules.script_callbacks.model_loaded_callback(shared.sd_model)
- startup_timer.record("model loaded callback")
-
- modelloader.load_upscalers()
- startup_timer.record("load upscalers")
-
- for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
- importlib.reload(module)
- startup_timer.record("reload script modules")
-
- modules.sd_models.list_models()
- startup_timer.record("list SD models")
-
- shared.reload_hypernetworks()
- startup_timer.record("reload hypernetworks")
-
- ui_extra_networks.intialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+ startup_timer.record("scripts unloaded callback")
+ initialize_rest(reload_script_modules=True)
- extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
- startup_timer.record("initialize extra networks")
+ modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
+ modules.sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
if __name__ == "__main__":