Diffstat (limited to 'modules')
 modules/devices.py                          |  89
 modules/extensions.py                       |   4
 modules/generation_parameters_copypaste.py  |   3
 modules/gradio_extensons.py                 |   4
 modules/images.py                           |  28
 modules/img2img.py                          |   6
 modules/initialize.py                       | 168
 modules/initialize_util.py                  | 195
 modules/launch_utils.py                     |   6
 modules/localization.py                     |   3
 modules/mac_specific.py                     |   4
 modules/options.py                          | 236
 modules/processing.py                       | 103
 modules/rng.py                              | 170
 modules/sd_hijack.py                        |   4
 modules/sd_hijack_inpainting.py             |  95
 modules/sd_models.py                        |  23
 modules/sd_models_config.py                 |   3
 modules/sd_samplers.py                      |  19
 modules/sd_samplers_cfg_denoiser.py         | 203
 modules/sd_samplers_common.py               | 138
 modules/sd_samplers_compvis.py              | 224
 modules/sd_samplers_kdiffusion.py           | 351
 modules/sd_samplers_timesteps.py            | 151
 modules/sd_samplers_timesteps_impl.py       | 135
 modules/sd_vae.py                           |   5
 modules/shared.py                           | 962
 modules/shared_cmd_options.py               |  18
 modules/shared_gradio_themes.py             |  66
 modules/shared_init.py                      |  49
 modules/shared_items.py                     |  49
 modules/shared_options.py                   | 316
 modules/shared_state.py                     | 159
 modules/shared_total_tqdm.py                |  37
 modules/sysinfo.py                          |   7
 modules/txt2img.py                          |   8
 modules/ui.py                               |  33
 modules/ui_common.py                        |   4
 modules/ui_extra_networks.py                |   3
 modules/ui_extra_networks_user_metadata.py  |   4
 modules/ui_tempdir.py                       |   5
 modules/util.py                             |  58
 42 files changed, 2311 insertions(+), 1837 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index 00a00b18..c01f0602 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors, rng_philox
+from modules import errors, shared
if sys.platform == "darwin":
from modules import mac_specific
@@ -17,8 +17,6 @@ def has_mps() -> bool:
def get_cuda_device_string():
- from modules import shared
-
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -40,8 +38,6 @@ def get_optimal_device():
def get_device_for(task):
- from modules import shared
-
if task in shared.cmd_opts.use_cpu:
return cpu
@@ -96,87 +92,7 @@ def cond_cast_float(input):
nv_rng = None
-def randn(seed, shape):
- """Generate a tensor with random numbers from a normal distribution using seed.
-
- Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
-
- from modules.shared import opts
-
- manual_seed(seed)
-
- if opts.randn_source == "NV":
- return torch.asarray(nv_rng.randn(shape), device=device)
-
- if opts.randn_source == "CPU" or device.type == 'mps':
- return torch.randn(shape, device=cpu).to(device)
-
- return torch.randn(shape, device=device)
-
-
-def randn_local(seed, shape):
- """Generate a tensor with random numbers from a normal distribution using seed.
-
- Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
-
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- rng = rng_philox.Generator(seed)
- return torch.asarray(rng.randn(shape), device=device)
-
- local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
- local_generator = torch.Generator(local_device).manual_seed(int(seed))
- return torch.randn(shape, device=local_device, generator=local_generator).to(device)
-
-
-def randn_like(x):
- """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
-
- Use either randn() or manual_seed() to initialize the generator."""
-
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
-
- if opts.randn_source == "CPU" or x.device.type == 'mps':
- return torch.randn_like(x, device=cpu).to(x.device)
-
- return torch.randn_like(x)
-
-
-def randn_without_seed(shape):
- """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
-
- Use either randn() or manual_seed() to initialize the generator."""
-
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- return torch.asarray(nv_rng.randn(shape), device=device)
-
- if opts.randn_source == "CPU" or device.type == 'mps':
- return torch.randn(shape, device=cpu).to(device)
-
- return torch.randn(shape, device=device)
-
-
-def manual_seed(seed):
- """Set up a global random number generator using the specified seed."""
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- global nv_rng
- nv_rng = rng_philox.Generator(seed)
- return
-
- torch.manual_seed(seed)
-
-
def autocast(disable=False):
- from modules import shared
-
if disable:
return contextlib.nullcontext()
@@ -195,8 +111,6 @@ class NansException(Exception):
def test_for_nans(x, where):
- from modules import shared
-
if shared.cmd_opts.disable_nan_check:
return
@@ -236,3 +150,4 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
+
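
A side note on the devices.py change above: every function-local `from modules import shared` becomes a single module-level import. Deferred imports like these are the usual workaround for circular imports, and the hoist is presumably made safe by this commit splitting shared.py into smaller modules (shared_init.py, shared_options.py, shared_state.py in the diffstat). A minimal sketch of the pattern being removed, with hypothetical modules a and b that are not from this repo:

    # a.py -- hypothetical
    import b              # safe even though b refers back to a
    VALUE = 42
    print(b.use_a())      # 42

    # b.py -- hypothetical
    def use_a():
        import a          # deferred: by call time, a.py has finished executing
        return a.VALUE
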
diff --git a/modules/extensions.py b/modules/extensions.py
index e4633af4..bf9a1878 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,7 +1,7 @@
import os
import threading
-from modules import shared, errors, cache
+from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -90,8 +90,6 @@ class Extension:
self.have_info_from_repo = True
def list_files(self, subdir, extension):
- from modules import scripts
-
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 5758e6f3..d932c67d 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -6,7 +6,7 @@ import re
import gradio as gr
from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks
+from modules import shared, ui_tempdir, script_callbacks, processing
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -198,7 +198,6 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
- from modules import processing
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
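
For context on the `re_param_code` pattern visible above: it tokenizes one line of infotext parameters into (key, value) pairs. A quick sketch using the regex exactly as defined in this file (the sample line is made up for illustration):

    import re

    re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
    re_param = re.compile(re_param_code)

    line = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 100"
    print(re_param.findall(line))
    # [('Steps', '20'), ('Sampler', 'Euler a'), ('CFG scale', '7'), ('Seed', '100')]
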
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
index 5af7fd8e..77c34c8b 100644
--- a/modules/gradio_extensons.py
+++ b/modules/gradio_extensons.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import scripts
+from modules import scripts, ui_tempdir
def add_classes_to_gradio_component(comp):
"""
@@ -58,3 +58,5 @@ original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.components.IOComponent.__init__ = IOComponent_init
gr.blocks.Block.get_config = Block_get_config
gr.blocks.BlockContext.__init__ = BlockContext_init
+
+ui_tempdir.install_ui_tempdir_override()
diff --git a/modules/images.py b/modules/images.py
index ba3c43a4..019c1d60 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
-import modules.sd_vae as sd_vae
-
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
- def get_vae_filename(self): #get the name of the VAE file.
- if sd_vae.loaded_vae_file is None:
- return "NoneType"
- file_name = os.path.basename(sd_vae.loaded_vae_file)
- split_file_name = file_name.split('.')
- if len(split_file_name) > 1 and split_file_name[0] == '':
- return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
- else:
- return split_file_name[0]
-
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
@@ -391,6 +379,22 @@ class FilenameGenerator:
self.image = image
self.zip = zip
+ def get_vae_filename(self):
+ """Get the name of the VAE file."""
+
+ import modules.sd_vae as sd_vae
+
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # a filename starting with "." splits into an empty first element, so the name is at index 1 (".vae.pt" -> "vae")
+ else:
+ return split_file_name[0]
+
+
def hasprompt(self, *args):
lower = self.prompt.lower()
if self.p is None or self.prompt is None:
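
The re-added get_vae_filename above special-cases VAE files whose names begin with a dot. A standalone sketch of that behavior (the helper name vae_display_name is hypothetical, the logic mirrors the method):

    import os

    def vae_display_name(path):
        parts = os.path.basename(path).split('.')
        if len(parts) > 1 and parts[0] == '':
            return parts[1]       # ".vae.pt" -> "vae"
        return parts[0]           # "anything-v3.vae.pt" -> "anything-v3"

    assert vae_display_name("/models/VAE/.vae.pt") == "vae"
    assert vae_display_name("/models/VAE/kl-f8-anime2.ckpt") == "kl-f8-anime2"
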
diff --git a/modules/img2img.py b/modules/img2img.py
index d8e1c534..e06ac1d6 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -6,7 +6,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr
-from modules import sd_samplers, images as imgutil
+from modules import images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
@@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p)
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
+ sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
diff --git a/modules/initialize.py b/modules/initialize.py
new file mode 100644
index 00000000..f24f7637
--- /dev/null
+++ b/modules/initialize.py
@@ -0,0 +1,168 @@
+import importlib
+import logging
+import sys
+import warnings
+from threading import Thread
+
+from modules.timer import startup_timer
+
+
+def imports():
+ logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
+ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
+ import torch # noqa: F401
+ startup_timer.record("import torch")
+ import pytorch_lightning # noqa: F401
+ startup_timer.record("import pytorch_lightning")
+ warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+
+ import gradio # noqa: F401
+ startup_timer.record("import gradio")
+
+ from modules import paths, timer, import_hook, errors # noqa: F401
+ startup_timer.record("setup paths")
+
+ import ldm.modules.encoders.modules # noqa: F401
+ startup_timer.record("import ldm")
+
+ import sgm.modules.encoders.modules # noqa: F401
+ startup_timer.record("import sgm")
+
+ from modules import shared_init
+ shared_init.initialize()
+ startup_timer.record("initialize shared")
+
+ from modules import processing, gradio_extensons, ui # noqa: F401
+ startup_timer.record("other imports")
+
+
+def check_versions():
+ from modules.shared_cmd_options import cmd_opts
+
+ if not cmd_opts.skip_version_check:
+ from modules import errors
+ errors.check_versions()
+
+
+def initialize():
+ from modules import initialize_util
+ initialize_util.fix_torch_version()
+ initialize_util.fix_asyncio_event_loop_policy()
+ initialize_util.validate_tls_options()
+ initialize_util.configure_sigint_handler()
+ initialize_util.configure_opts_onchange()
+
+ from modules import modelloader
+ modelloader.cleanup_models()
+
+ from modules import sd_models
+ sd_models.setup_model()
+ startup_timer.record("setup SD model")
+
+ from modules.shared_cmd_options import cmd_opts
+
+ from modules import codeformer_model
+ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
+ codeformer_model.setup_model(cmd_opts.codeformer_models_path)
+ startup_timer.record("setup codeformer")
+
+ from modules import gfpgan_model
+ gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
+ startup_timer.record("setup gfpgan")
+
+ initialize_rest(reload_script_modules=False)
+
+
+def initialize_rest(*, reload_script_modules=False):
+ """
+ Called both from initialize() and when reloading the webui.
+ """
+ from modules.shared_cmd_options import cmd_opts
+
+ from modules import sd_samplers
+ sd_samplers.set_samplers()
+ startup_timer.record("set samplers")
+
+ from modules import extensions
+ extensions.list_extensions()
+ startup_timer.record("list extensions")
+
+ from modules import initialize_util
+ initialize_util.restore_config_state_file()
+ startup_timer.record("restore config state file")
+
+ from modules import shared, upscaler, scripts
+ if cmd_opts.ui_debug_mode:
+ shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
+ scripts.load_scripts()
+ return
+
+ from modules import sd_models
+ sd_models.list_models()
+ startup_timer.record("list SD models")
+
+ from modules import localization
+ localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list localizations")
+
+ with startup_timer.subcategory("load scripts"):
+ scripts.load_scripts()
+
+ if reload_script_modules:
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+ startup_timer.record("reload script modules")
+
+ from modules import modelloader
+ modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
+
+ from modules import sd_vae
+ sd_vae.refresh_vae_list()
+ startup_timer.record("refresh VAE")
+
+ from modules import textual_inversion
+ textual_inversion.textual_inversion.list_textual_inversion_templates()
+ startup_timer.record("refresh textual inversion templates")
+
+ from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
+ script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
+ sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
+
+ from modules import sd_unet
+ sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
+
+ def load_model():
+ """
+ Accesses the shared.sd_model property to load the model.
+ If some extension has already triggered the load before this point, the model's
+ optimizer may be None, because the list of optimizers had not been filled yet
+ by that time, so we apply the optimizations again.
+ """
+
+ shared.sd_model # noqa: B018
+
+ if sd_hijack.current_optimizer is None:
+ sd_hijack.apply_optimizations()
+
+ from modules import devices
+ devices.first_time_calculation()
+
+ Thread(target=load_model).start()
+
+ from modules import shared_items
+ shared_items.reload_hypernetworks()
+ startup_timer.record("reload hypernetworks")
+
+ from modules import ui_extra_networks
+ ui_extra_networks.initialize()
+ ui_extra_networks.register_default_pages()
+
+ from modules import extra_networks
+ extra_networks.initialize()
+ extra_networks.register_default_extra_networks()
+ startup_timer.record("initialize extra networks")
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
new file mode 100644
index 00000000..e59bd3c4
--- /dev/null
+++ b/modules/initialize_util.py
@@ -0,0 +1,195 @@
+import json
+import logging
+import os
+import signal
+import sys
+import re
+
+from modules.timer import startup_timer
+
+def setup_logging():
+ # We can't use cmd_opts for this because it will not have been initialized at this point.
+ log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
+ if log_level:
+ log_level = getattr(logging, log_level.upper(), None) or logging.INFO
+ logging.basicConfig(
+ level=log_level,
+ format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ )
+
+
+def gradio_server_name():
+ from modules.shared_cmd_options import cmd_opts
+
+ if cmd_opts.server_name:
+ return cmd_opts.server_name
+ else:
+ return "0.0.0.0" if cmd_opts.listen else None
+
+
+def fix_torch_version():
+ import torch
+
+ # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+ if ".dev" in torch.__version__ or "+git" in torch.__version__:
+ torch.__long_version__ = torch.__version__
+ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
+
+
+def fix_asyncio_event_loop_policy():
+ """
+ The default `asyncio` event loop policy only automatically creates
+ event loops in the main threads. Other threads must create event
+ loops explicitly or `asyncio.get_event_loop` (and therefore
+ `.IOLoop.current`) will fail. Installing this policy allows event
+ loops to be created automatically on any thread, matching the
+ behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
+ """
+
+ import asyncio
+
+ if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
+ # "Any thread" and "selector" should be orthogonal, but there's not a clean
+ # interface for composing policies so pick the right base.
+ _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
+ else:
+ _BasePolicy = asyncio.DefaultEventLoopPolicy
+
+ class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
+ """Event loop policy that allows loop creation on any thread.
+ Usage::
+
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
+ """
+
+ def get_event_loop(self) -> asyncio.AbstractEventLoop:
+ try:
+ return super().get_event_loop()
+ except (RuntimeError, AssertionError):
+ # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
+ # and changed to a RuntimeError in 3.4.3.
+ # "There is no current event loop in thread %r"
+ loop = self.new_event_loop()
+ self.set_event_loop(loop)
+ return loop
+
+ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
+
+
+def restore_config_state_file():
+ from modules import shared, config_states
+
+ config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
+ shared.opts.restore_config_state_file = ""
+ shared.opts.save(shared.config_filename)
+
+ if os.path.isfile(config_state_file):
+ print(f"*** About to restore extension state from file: {config_state_file}")
+ with open(config_state_file, "r", encoding="utf-8") as f:
+ config_state = json.load(f)
+ config_states.restore_extension_config(config_state)
+ startup_timer.record("restore extension config")
+ elif config_state_file:
+ print(f"!!! Config state backup not found: {config_state_file}")
+
+
+def validate_tls_options():
+ from modules.shared_cmd_options import cmd_opts
+
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
+ return
+
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
+ startup_timer.record("TLS")
+
+
+def get_gradio_auth_creds():
+ """
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
+ an iterable of (username, password) tuples.
+ """
+ from modules.shared_cmd_options import cmd_opts
+
+ def process_credential_line(s):
+ s = s.strip()
+ if not s:
+ return None
+ return tuple(s.split(':', 1))
+
+ if cmd_opts.gradio_auth:
+ for cred in cmd_opts.gradio_auth.split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+ if cmd_opts.gradio_auth_path:
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ for cred in line.strip().split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+
+def configure_sigint_handler():
+ # make the program just exit at ctrl+c without waiting for anything
+ def sigint_handler(sig, frame):
+ print(f'Interrupted with signal {sig} in {frame}')
+ os._exit(0)
+
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
+
+
+def configure_opts_onchange():
+ from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
+ from modules.call_queue import wrap_queued_call
+
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ startup_timer.record("opts onchange")
+
+
+def setup_middleware(app):
+ from starlette.middleware.gzip import GZipMiddleware
+
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ configure_cors_middleware(app)
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
+
+
+def configure_cors_middleware(app):
+ from starlette.middleware.cors import CORSMiddleware
+ from modules.shared_cmd_options import cmd_opts
+
+ cors_options = {
+ "allow_methods": ["*"],
+ "allow_headers": ["*"],
+ "allow_credentials": True,
+ }
+ if cmd_opts.cors_allow_origins:
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
+ if cmd_opts.cors_allow_origins_regex:
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
+ app.add_middleware(CORSMiddleware, **cors_options)
+
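
The AnyThreadEventLoopPolicy docstring above describes the failure mode being patched. A small demonstration of it, assuming a Python version where asyncio.get_event_loop() raises in a non-main thread instead of implicitly creating a loop:

    import asyncio
    import threading

    def worker():
        try:
            asyncio.get_event_loop()
            print("got a loop")
        except RuntimeError as e:
            print("no loop:", e)   # "There is no current event loop in thread ..."

    t = threading.Thread(target=worker)
    t.start(); t.join()
    # After initialize_util.fix_asyncio_event_loop_policy() is called, the same
    # worker gets a freshly created loop instead of the RuntimeError.
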
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 5be30a18..7143f144 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -248,7 +248,11 @@ def run_extensions_installers(settings_file):
return
with startup_timer.subcategory("run extensions installers"):
- for dirname_extension in list_extensions(settings_file):
+ import tqdm
+ progress_bar = tqdm.tqdm(list_extensions(settings_file))
+ for dirname_extension in progress_bar:
+ progress_bar.set_description(f"Installing {dirname_extension}")
+
path = os.path.join(extensions_dir, dirname_extension)
if os.path.isdir(path):
diff --git a/modules/localization.py b/modules/localization.py
index e8f585da..c1320288 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,7 +1,7 @@
import json
import os
-from modules import errors
+from modules import errors, scripts
localizations = {}
@@ -16,7 +16,6 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
- from modules import scripts
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 9ceb43ba..bce527cc 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -4,6 +4,7 @@ import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+from modules import shared
log = logging.getLogger(__name__)
@@ -30,8 +31,7 @@ has_mps = check_for_mps()
def torch_mps_gc() -> None:
try:
- from modules.shared import state
- if state.current_latent is not None:
+ if shared.state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection")
return
from torch.mps import empty_cache
diff --git a/modules/options.py b/modules/options.py
new file mode 100644
index 00000000..59cb75ec
--- /dev/null
+++ b/modules/options.py
@@ -0,0 +1,236 @@
+import json
+import sys
+
+import gradio as gr
+
+from modules import errors
+from modules.shared_cmd_options import cmd_opts
+
+
+class OptionInfo:
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
+ self.default = default
+ self.label = label
+ self.component = component
+ self.component_args = component_args
+ self.onchange = onchange
+ self.section = section
+ self.refresh = refresh
+ self.do_not_save = False
+
+ self.comment_before = comment_before
+ """HTML text that will be added after label in UI"""
+
+ self.comment_after = comment_after
+ """HTML text that will be added before label in UI"""
+
+ def link(self, label, url):
+ self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
+ return self
+
+ def js(self, label, js_func):
+ self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
+ return self
+
+ def info(self, info):
+ self.comment_after += f"<span class='info'>({info})</span>"
+ return self
+
+ def html(self, html):
+ self.comment_after += html
+ return self
+
+ def needs_restart(self):
+ self.comment_after += " <span class='info'>(requires restart)</span>"
+ return self
+
+ def needs_reload_ui(self):
+ self.comment_after += " <span class='info'>(requires Reload UI)</span>"
+ return self
+
+
+class OptionHTML(OptionInfo):
+ def __init__(self, text):
+ super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
+
+ self.do_not_save = True
+
+
+def options_section(section_identifier, options_dict):
+ for v in options_dict.values():
+ v.section = section_identifier
+
+ return options_dict
+
+
+options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
+
+
+class Options:
+ typemap = {int: float}
+
+ def __init__(self, data_labels, restricted_opts):
+ self.data_labels = data_labels
+ self.data = {k: v.default for k, v in self.data_labels.items()}
+ self.restricted_opts = restricted_opts
+
+ def __setattr__(self, key, value):
+ if key in options_builtin_fields:
+ return super(Options, self).__setattr__(key, value)
+
+ if self.data is not None:
+ if key in self.data or key in self.data_labels:
+ assert not cmd_opts.freeze_settings, "changing settings is disabled"
+
+ info = self.data_labels.get(key, None)
+ if info and info.do_not_save:
+ return
+
+ comp_args = info.component_args if info else None
+ if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
+ if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
+ self.data[key] = value
+ return
+
+ return super(Options, self).__setattr__(key, value)
+
+ def __getattr__(self, item):
+ if item in options_builtin_fields:
+ return super(Options, self).__getattribute__(item)
+
+ if self.data is not None:
+ if item in self.data:
+ return self.data[item]
+
+ if item in self.data_labels:
+ return self.data_labels[item].default
+
+ return super(Options, self).__getattribute__(item)
+
+ def set(self, key, value):
+ """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
+
+ oldval = self.data.get(key, None)
+ if oldval == value:
+ return False
+
+ if self.data_labels[key].do_not_save:
+ return False
+
+ try:
+ setattr(self, key, value)
+ except RuntimeError:
+ return False
+
+ if self.data_labels[key].onchange is not None:
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
+
+ return True
+
+ def get_default(self, key):
+ """returns the default value for the key"""
+
+ data_label = self.data_labels.get(key)
+ if data_label is None:
+ return None
+
+ return data_label.default
+
+ def save(self, filename):
+ assert not cmd_opts.freeze_settings, "saving settings is disabled"
+
+ with open(filename, "w", encoding="utf8") as file:
+ json.dump(self.data, file, indent=4)
+
+ def same_type(self, x, y):
+ if x is None or y is None:
+ return True
+
+ type_x = self.typemap.get(type(x), type(x))
+ type_y = self.typemap.get(type(y), type(y))
+
+ return type_x == type_y
+
+ def load(self, filename):
+ with open(filename, "r", encoding="utf8") as file:
+ self.data = json.load(file)
+
+ # 1.6.0 VAE defaults
+ if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+ self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
+ # 1.1.1 quicksettings list migration
+ if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
+ self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+
+ # 1.4.0 ui_reorder
+ if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
+ self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
+
+ bad_settings = 0
+ for k, v in self.data.items():
+ info = self.data_labels.get(k, None)
+ if info is not None and not self.same_type(info.default, v):
+ print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
+ bad_settings += 1
+
+ if bad_settings > 0:
+ print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
+
+ def onchange(self, key, func, call=True):
+ item = self.data_labels.get(key)
+ item.onchange = func
+
+ if call:
+ func()
+
+ def dumpjson(self):
+ d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
+ d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
+ d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
+ return json.dumps(d)
+
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for _, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
+
+ def cast_value(self, key, value):
+ """casts an arbitrary to the same type as this setting's value with key
+ Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
+ """
+
+ if value is None:
+ return None
+
+ default_value = self.data_labels[key].default
+ if default_value is None:
+ default_value = getattr(self, key, None)
+ if default_value is None:
+ return None
+
+ expected_type = type(default_value)
+ if expected_type == bool and value == "False":
+ value = False
+ else:
+ value = expected_type(value)
+
+ return value
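
A short usage sketch for the new Options container; this only runs inside a webui process (the module pulls in cmd_opts), the option key is real but the values are examples:

    from modules.options import Options, OptionInfo, options_section

    templates = options_section(('compat', "Compatibility"), {
        "eta_noise_seed_delta": OptionInfo(0, "Eta noise seed delta"),
    })
    opts = Options(templates, restricted_opts=set())

    print(opts.eta_noise_seed_delta)                       # 0 (the default)
    opts.set("eta_noise_seed_delta", 31337)                # True; onchange fires if registered
    print(opts.cast_value("eta_noise_seed_delta", "12"))   # 12 -- str cast to the option's type
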
diff --git a/modules/processing.py b/modules/processing.py
index 31745006..2df5e8c7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,7 +14,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
@@ -172,6 +172,8 @@ class StableDiffusionProcessing:
self.iteration = 0
self.is_hr_pass = False
self.sampler = None
+ self.main_prompt = None
+ self.main_negative_prompt = None
self.prompts = None
self.negative_prompts = None
@@ -184,6 +186,7 @@ class StableDiffusionProcessing:
self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None
+ self.rng: rng.ImageRNG = None
self.user = None
@@ -319,6 +322,9 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+ self.main_prompt = self.all_prompts[0]
+ self.main_negative_prompt = self.all_negative_prompts[0]
+
def cached_params(self, required_prompts, steps, extra_network_data):
"""Returns parameters that invalidate the cond cache if changed"""
@@ -470,82 +476,9 @@ class Processed:
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
-# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-def slerp(val, low, high):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- dot = (low_norm*high_norm).sum(1)
-
- if dot.mean() > 0.9995:
- return low * val + high * (1 - val)
-
- omega = torch.acos(dot)
- so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
- return res
-
-
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
- eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
- xs = []
-
- # if we have multiple seeds, this means we are working with batch size>1; this then
- # enables the generation of additional tensors with noise that the sampler will use during its processing.
- # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
- # produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
- sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
- else:
- sampler_noises = None
-
- for i, seed in enumerate(seeds):
- noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
-
- subnoise = None
- if subseeds is not None and subseed_strength != 0:
- subseed = 0 if i >= len(subseeds) else subseeds[i]
-
- subnoise = devices.randn(subseed, noise_shape)
-
- # randn results depend on device; gpu and cpu get different results for same seed;
- # the way I see it, it's better to do this on CPU, so that everyone gets same result;
- # but the original script had it like this, so I do not dare change it for now because
- # it will break everyone's seeds.
- noise = devices.randn(seed, noise_shape)
-
- if subnoise is not None:
- noise = slerp(subseed_strength, noise, subnoise)
-
- if noise_shape != shape:
- x = devices.randn(seed, shape)
- dx = (shape[2] - noise_shape[2]) // 2
- dy = (shape[1] - noise_shape[1]) // 2
- w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
- h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
- tx = 0 if dx < 0 else dx
- ty = 0 if dy < 0 else dy
- dx = max(-dx, 0)
- dy = max(-dy, 0)
-
- x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
- noise = x
-
- if sampler_noises is not None:
- cnt = p.sampler.number_of_needed_noises(p)
-
- if eta_noise_seed_delta > 0:
- devices.manual_seed(seed + eta_noise_seed_delta)
-
- for j in range(cnt):
- sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
-
- xs.append(noise)
-
- if sampler_noises is not None:
- p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
-
- x = torch.stack(xs).to(shared.device)
- return x
+ g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
+ return g.next()
class DecodedSamples(list):
@@ -568,7 +501,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
"Web UI will now convert VAE into 32-bit float and retry.\n"
- "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
+ "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
)
@@ -653,8 +586,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- prompt_text = p.prompt if use_main_prompt else all_prompts[index]
- negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
+ prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
+ negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -764,6 +697,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+
if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -1067,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = self.rng.next()
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x
@@ -1112,9 +1047,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
- if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
- img2img_sampler_name = 'DDIM'
-
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
if self.latent_scale_mode is not None:
@@ -1158,7 +1090,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
+ self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
+ noise = self.rng.next()
# GC now before running the next img2img to prevent running out of memory
devices.torch_gc()
@@ -1416,7 +1349,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = self.rng.next()
if self.initial_noise_multiplier != 1.0:
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
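
The removed comment block explained why create_random_tensors pre-generated noise per seed: it makes a batch seeded [100, 101] reproduce the images of two single-image batches [100] and [101]. The new ImageRNG keeps one torch.Generator per seed for the same reason; a self-contained sketch of the property using plain torch on CPU:

    import torch

    def noise_for(seeds, shape=(4, 8, 8)):
        gens = [torch.Generator("cpu").manual_seed(s) for s in seeds]
        return torch.stack([torch.randn(shape, generator=g) for g in gens])

    batch = noise_for([100, 101])
    singles = torch.cat([noise_for([100]), noise_for([101])])
    assert torch.equal(batch, singles)   # per-seed generators make these match
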
diff --git a/modules/rng.py b/modules/rng.py
new file mode 100644
index 00000000..f927a318
--- /dev/null
+++ b/modules/rng.py
@@ -0,0 +1,170 @@
+import torch
+
+from modules import devices, rng_philox, shared
+
+
+def randn(seed, shape, generator=None):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
+ manual_seed(seed)
+
+ if shared.opts.randn_source == "NV":
+ return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
+
+ if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
+ return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
+
+ return torch.randn(shape, device=devices.device, generator=generator)
+
+
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ if shared.opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=devices.device)
+
+ local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
+
+
+def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ if shared.opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
+def randn_without_seed(shape, generator=None):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ if shared.opts.randn_source == "NV":
+ return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
+
+ if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
+ return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
+
+ return torch.randn(shape, device=devices.device, generator=generator)
+
+
+def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
+
+ if shared.opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
+def create_generator(seed):
+ if shared.opts.randn_source == "NV":
+ return rng_philox.Generator(seed)
+
+ device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return generator
+
+
+# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
+def slerp(val, low, high):
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
+ dot = (low_norm*high_norm).sum(1)
+
+ if dot.mean() > 0.9995:
+ return low * val + high * (1 - val)
+
+ omega = torch.acos(dot)
+ so = torch.sin(omega)
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ return res
+
+
+class ImageRNG:
+ def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
+ self.shape = shape
+ self.seeds = seeds
+ self.subseeds = subseeds
+ self.subseed_strength = subseed_strength
+ self.seed_resize_from_h = seed_resize_from_h
+ self.seed_resize_from_w = seed_resize_from_w
+
+ self.generators = [create_generator(seed) for seed in seeds]
+
+ self.is_first = True
+
+ def first(self):
+ noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
+
+ xs = []
+
+ for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
+ subnoise = None
+ if self.subseeds is not None and self.subseed_strength != 0:
+ subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
+ subnoise = randn(subseed, noise_shape)
+
+ if noise_shape != self.shape:
+ noise = randn(seed, noise_shape)
+ else:
+ noise = randn(seed, self.shape, generator=generator)
+
+ if subnoise is not None:
+ noise = slerp(self.subseed_strength, noise, subnoise)
+
+ if noise_shape != self.shape:
+ x = randn(seed, self.shape, generator=generator)
+ dx = (self.shape[2] - noise_shape[2]) // 2
+ dy = (self.shape[1] - noise_shape[1]) // 2
+ w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
+ h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
+ tx = 0 if dx < 0 else dx
+ ty = 0 if dy < 0 else dy
+ dx = max(-dx, 0)
+ dy = max(-dy, 0)
+
+ x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
+ noise = x
+
+ xs.append(noise)
+
+ eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
+ if eta_noise_seed_delta:
+ self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
+
+ return torch.stack(xs).to(shared.device)
+
+ def next(self):
+ if self.is_first:
+ self.is_first = False
+ return self.first()
+
+ xs = []
+ for generator in self.generators:
+ x = randn_without_seed(self.shape, generator=generator)
+ xs.append(x)
+
+ return torch.stack(xs).to(shared.device)
+
+
+devices.randn = randn
+devices.randn_local = randn_local
+devices.randn_like = randn_like
+devices.randn_without_seed = randn_without_seed
+devices.manual_seed = manual_seed
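
A quick numeric check of the slerp above (endpoints: val=0 returns low, val=1 returns high). This assumes the webui package is importable; otherwise copy the function out verbatim:

    import torch
    from modules.rng import slerp

    low = torch.tensor([[1.0, 0.0]])     # orthogonal unit vectors, dot = 0
    high = torch.tensor([[0.0, 1.0]])

    print(slerp(0.5, low, high))         # tensor([[0.7071, 0.7071]]) -- 45 degrees
    print(slerp(0.0, low, high))         # returns low
    print(slerp(1.0, low, high))         # returns high
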
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9ad98199..46652fbd 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
-sd_hijack_inpainting.do_inpainting_hijack()
-
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
deleted file mode 100644
index 2d44b856..00000000
--- a/modules/sd_hijack_inpainting.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import torch
-
-import ldm.models.diffusion.ddpm
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-from ldm.models.diffusion.ddim import noise_like
-from ldm.models.diffusion.sampling_util import norm_thresholding
-
-
-@torch.no_grad()
-def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
- b, *_, device = *x.shape, x.device
-
- def get_model_output(x, t):
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
-
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = {}
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- else:
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- c_in = torch.cat([unconditional_conditioning, c])
-
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- return e_t
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
- def get_x_prev_and_pred_x0(e_t, index):
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- if dynamic_threshold is not None:
- pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- e_t = get_model_output(x, t)
- if len(old_eps) == 0:
- # Pseudo Improved Euler (2nd order)
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
- e_t_next = get_model_output(x_prev, t_next)
- e_t_prime = (e_t + e_t_next) / 2
- elif len(old_eps) == 1:
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
- elif len(old_eps) == 2:
- # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
- elif len(old_eps) >= 3:
- # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
- return x_prev, pred_x0, e_t
-
-
-def do_inpainting_hijack():
- ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 53c1df54..7a866a07 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer
import tomesd
@@ -68,7 +68,9 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
+ if self.shorthash:
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
def register(self):
checkpoints_list[self.title] = self
@@ -80,10 +82,14 @@ class CheckpointInfo:
if self.sha256 is None:
return
- self.shorthash = self.sha256[0:10]
+ shorthash = self.sha256[0:10]
+ if self.shorthash == self.sha256[0:10]:
+ return self.shorthash
+
+ self.shorthash = shorthash
if self.shorthash not in self.ids:
- self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
@@ -473,7 +479,6 @@ model_data = SdModelData()
def get_empty_cond(sd_model):
- from modules import extra_networks, processing
p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {})
@@ -486,8 +491,6 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
- from modules import lowvram
-
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@@ -497,8 +500,6 @@ def send_model_to_cpu(m):
def send_model_to_device(m):
- from modules import lowvram
-
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
else:
@@ -623,6 +624,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
timer.record("send model to device")
model_data.set_sd_model(already_loaded)
+ shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
+ shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
@@ -642,7 +645,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
def reload_model_weights(sd_model=None, info=None):
- from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -705,7 +707,6 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 8266fa39..08dd03f1 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -2,7 +2,7 @@ import os
import torch
-from modules import shared, paths, sd_disable_initialization
+from modules import shared, paths, sd_disable_initialization, devices
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
@@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict):
"""
import ldm.modules.diffusionmodules.openaimodel
- from modules import devices
device = devices.cpu
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index bea2684c..45faae62 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,17 +1,18 @@
-from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
- *sd_samplers_compvis.samplers_data_compvis,
+ *sd_samplers_timesteps.samplers_data_timesteps,
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers_map = {}
+samplers_hidden = {}
def find_sampler_config(name):
@@ -38,13 +39,11 @@ def create_sampler(name, model):
def set_samplers():
- global samplers, samplers_for_img2img
+ global samplers, samplers_for_img2img, samplers_hidden
- hidden = set(shared.opts.hide_samplers)
- hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
-
- samplers = [x for x in all_samplers if x.name not in hidden]
- samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+ samplers_hidden = set(shared.opts.hide_samplers)
+ samplers = all_samplers
+ samplers_for_img2img = all_samplers
samplers_map.clear()
for sampler in all_samplers:
@@ -53,4 +52,8 @@ def set_samplers():
samplers_map[alias.lower()] = sampler.name
+def visible_sampler_names():
+ return [x.name for x in samplers if x.name not in samplers_hidden]
+
+
set_samplers()
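
With hiding now applied at read time rather than when the sampler lists are built, UI code is expected to go through the new helper instead of filtering all_samplers itself. A usage sketch, assuming the module is imported as shown:

from modules import sd_samplers

sd_samplers.set_samplers()  # refreshes samplers_hidden from shared.opts.hide_samplers
choices = sd_samplers.visible_sampler_names()  # e.g. options for a sampler dropdown
assert all(name not in sd_samplers.samplers_hidden for name in choices)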
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
new file mode 100644
index 00000000..d826222c
--- /dev/null
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -0,0 +1,203 @@
+import torch
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+
+
+def catenate_conds(conds):
+ if not isinstance(conds[0], dict):
+ return torch.cat(conds)
+
+ return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+ if not isinstance(cond, dict):
+ return cond[a:b]
+
+ return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+ if not isinstance(tensor, dict):
+ return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+ tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+ return tensor
+
+
+class CFGDenoiser(torch.nn.Module):
+ """
+ Classifier-free guidance denoiser. A wrapper for the stable diffusion model (specifically for the unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt was just an empty string, but we use a non-empty
+ negative prompt.
+ """
+
+ def __init__(self, model, sampler):
+ super().__init__()
+ self.inner_model = model
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.step = 0
+ self.image_cfg_scale = None
+ self.padded_cond_uncond = False
+ self.sampler = sampler
+
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+ return denoised
+
+ def get_pred_x0(self, x_in, x_out, sigma):
+ return x_out
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ # at self.image_cfg_scale == 1.0, results produced for the edit model are the same as with normal sampling,
+ # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+ if self.mask is not None:
+ x = self.init_latent * self.mask + self.nmask * x
+
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ if shared.sd_model.model.conditioning_key == "crossattn-adm":
+ image_uncond = torch.zeros_like(image_cond)
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
+ else:
+ image_uncond = image_cond
+ if isinstance(uncond, dict):
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+ else:
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
+
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
+ skip_uncond = False
+
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ self.padded_cond_uncond = False
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = pad_cond(tensor, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
+ cond_in = catenate_conds([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = catenate_conds([tensor, uncond])
+
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+
+ if not is_edit_model:
+ c_crossattn = subscript_cond(tensor, a, b)
+ else:
+ c_crossattn = torch.cat([tensor[a:b], uncond])  # concatenate cond and uncond along the batch axis
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
+
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put the cond-denoised image where the uncond-denoised image should be
+
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+ cfg_denoised_callback(denoised_params)
+
+ devices.test_for_nans(x_out, "unet")
+
+ if is_edit_model:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+ self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+ if opts.live_preview_content == "Prompt":
+ preview = self.sampler.last_latent
+ elif opts.live_preview_content == "Negative prompt":
+ preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+ else:
+ preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+ sd_samplers_common.store_latent(preview)
+
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
+ self.step += 1
+ return denoised
+
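
combine_denoised above is the standard classifier-free guidance blend, generalized to weighted AND composition. With a single prompt of weight 1.0 it reduces to the familiar formula; a minimal tensor-level sketch (shapes assumed for illustration):

import torch

def cfg_blend(denoised_cond, denoised_uncond, cond_scale):
    # uncond + scale * (cond - uncond): the single-condition special case
    # of CFGDenoiser.combine_denoised with weight == 1.0
    return denoised_uncond + cond_scale * (denoised_cond - denoised_uncond)

x_cond = torch.randn(1, 4, 64, 64)
x_uncond = torch.randn(1, 4, 64, 64)
blended = cfg_blend(x_cond, x_uncond, cond_scale=7.0)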
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 39586b40..97bc0804 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,9 +1,11 @@
+import inspect
from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
+import k_diffusion.sampling
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -127,3 +129,139 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
+
+
+class TorchHijack:
+ """This is here to replace torch.randn_like of k-diffusion.
+
+ k-diffusion has a random_sampler argument for most samplers, but not for all, so
+ this is needed to properly replace every use of torch.randn_like.
+
+ We need to replace it so that images generated in batches are the same as images generated individually."""
+
+ def __init__(self, p):
+ self.rng = p.rng
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.randn_like
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
+
+ def randn_like(self, x):
+ return self.rng.next()
+
+
+class Sampler:
+ def __init__(self, funcname):
+ self.funcname = funcname
+ self.func = funcname
+ self.extra_params = []
+ self.sampler_noises = None
+ self.stop_at = None
+ self.eta = None
+ self.config = None # set by the function calling the constructor
+ self.last_latent = None
+ self.s_min_uncond = None
+ self.s_churn = 0.0
+ self.s_tmin = 0.0
+ self.s_tmax = float('inf')
+ self.s_noise = 1.0
+
+ self.eta_option_field = 'eta_ancestral'
+ self.eta_infotext_field = 'Eta'
+
+ self.conditioning_key = shared.sd_model.model.conditioning_key
+
+ self.model_wrap = None
+ self.model_wrap_cfg = None
+
+ def callback_state(self, d):
+ step = d['i']
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
+ except InterruptedException:
+ return self.last_latent
+
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def initialize(self, p) -> dict:
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.step = 0
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+ self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
+
+ k_diffusion.sampling.torch = TorchHijack(p)
+
+ extra_params_kwargs = {}
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
+
+ if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params[self.eta_infotext_field] = self.eta
+
+ extra_params_kwargs['eta'] = self.eta
+
+ if len(self.extra_params) > 0:
+ s_churn = getattr(opts, 's_churn', p.s_churn)
+ s_tmin = getattr(opts, 's_tmin', p.s_tmin)
+ s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
+ s_noise = getattr(opts, 's_noise', p.s_noise)
+
+ if s_churn != self.s_churn:
+ extra_params_kwargs['s_churn'] = s_churn
+ p.s_churn = s_churn
+ p.extra_generation_params['Sigma churn'] = s_churn
+ if s_tmin != self.s_tmin:
+ extra_params_kwargs['s_tmin'] = s_tmin
+ p.s_tmin = s_tmin
+ p.extra_generation_params['Sigma tmin'] = s_tmin
+ if s_tmax != self.s_tmax:
+ extra_params_kwargs['s_tmax'] = s_tmax
+ p.s_tmax = s_tmax
+ p.extra_generation_params['Sigma tmax'] = s_tmax
+ if s_noise != self.s_noise:
+ extra_params_kwargs['s_noise'] = s_noise
+ p.s_noise = s_noise
+ p.extra_generation_params['Sigma noise'] = s_noise
+
+ return extra_params_kwargs
+
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
+
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
+
+
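
create_noise_sampler seeds the Brownian tree with the per-image seeds of the current iteration, which is what makes DPM++ SDE output independent of batch size. A sketch of just the seed slicing, assuming all_seeds, iteration and batch_size as on StableDiffusionProcessing:

def current_iter_seeds(all_seeds, iteration, batch_size):
    # each image keeps the seed it would have had if generated alone,
    # so BrownianTreeNoiseSampler draws identical noise either way
    return all_seeds[iteration * batch_size:(iteration + 1) * batch_size]

all_seeds = [100, 101, 102, 103]
assert current_iter_seeds(all_seeds, iteration=1, batch_size=2) == [102, 103]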
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
deleted file mode 100644
index 4a8396f9..00000000
--- a/modules/sd_samplers_compvis.py
+++ /dev/null
@@ -1,224 +0,0 @@
-import math
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-import numpy as np
-import torch
-
-from modules.shared import state
-from modules import sd_samplers_common, prompt_parser, shared
-import modules.models.diffusion.uni_pc
-
-
-samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
- sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
- sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
-]
-
-
-class VanillaStableDiffusionSampler:
- def __init__(self, constructor, sd_model):
- self.sampler = constructor(sd_model)
- self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
- self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
- self.orig_p_sample_ddim = None
- if self.is_plms:
- self.orig_p_sample_ddim = self.sampler.p_sample_plms
- elif self.is_ddim:
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.sampler_noises = None
- self.step = 0
- self.stop_at = None
- self.eta = None
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def number_of_needed_noises(self, p):
- return 0
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
-
- res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
-
- x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
-
- return res
-
- def before_sample(self, x, ts, cond, unconditional_conditioning):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
- if self.stop_at is not None and self.step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- # Have to unwrap the inpainting conditioning here to perform pre-processing
- image_conditioning = None
- uc_image_conditioning = None
- if isinstance(cond, dict):
- if self.conditioning_key == "crossattn-adm":
- image_conditioning = cond["c_adm"]
- uc_image_conditioning = unconditional_conditioning["c_adm"]
- else:
- image_conditioning = cond["c_concat"][0]
- cond = cond["c_crossattn"][0]
- unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
- assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
- cond = tensor
-
- # for DDIM, shapes must match, we can't just process cond and uncond independently;
- # filling unconditional_conditioning with repeats of the last vector to match length is
- # not 100% correct but should work well enough
- if unconditional_conditioning.shape[1] < cond.shape[1]:
- last_vector = unconditional_conditioning[:, -1:]
- last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
- unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
- elif unconditional_conditioning.shape[1] > cond.shape[1]:
- unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
- if self.mask is not None:
- img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x = img_orig * self.mask + self.nmask * x
-
- # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
- # Note that they need to be lists because it just concatenates them later.
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
- unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
- else:
- cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- return x, ts, cond, unconditional_conditioning
-
- def update_step(self, last_latent):
- if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
- else:
- self.last_latent = last_latent
-
- sd_samplers_common.store_latent(self.last_latent)
-
- self.step += 1
- state.sampling_step = self.step
- shared.total_tqdm.update()
-
- def after_sample(self, x, ts, cond, uncond, res):
- if not self.is_unipc:
- self.update_step(res[1])
-
- return x, ts, cond, uncond, res
-
- def unipc_after_update(self, x, model_x):
- self.update_step(x)
-
- def initialize(self, p):
- if self.is_ddim:
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
- else:
- self.eta = 0.0
-
- if self.eta != 0.0:
- p.extra_generation_params["Eta DDIM"] = self.eta
-
- if self.is_unipc:
- keys = [
- ('UniPC variant', 'uni_pc_variant'),
- ('UniPC skip type', 'uni_pc_skip_type'),
- ('UniPC order', 'uni_pc_order'),
- ('UniPC lower order final', 'uni_pc_lower_order_final'),
- ]
-
- for name, key in keys:
- v = getattr(shared.opts, key)
- if v != shared.opts.get_default(key):
- p.extra_generation_params[name] = v
-
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
- if self.is_unipc:
- self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
-
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
-
- def adjust_steps_if_invalid(self, p, num_steps):
- if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
- if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
- num_steps = shared.opts.uni_pc_order
- valid_step = 999 / (1000 // num_steps)
- if valid_step == math.floor(valid_step):
- return int(valid_step) + 1
-
- return num_steps
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
- steps = self.adjust_steps_if_invalid(p, steps)
- self.initialize(p)
-
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
- self.init_latent = x
- self.last_latent = x
- self.step = 0
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
- else:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- self.initialize(p)
-
- self.init_latent = None
- self.last_latent = x
- self.step = 0
-
- steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
- if image_conditioning is not None:
- if self.conditioning_key == "crossattn-adm":
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
- else:
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
- return samples_ddim
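
The deleted adjust_steps_if_invalid bumped the requested step count whenever a uniform DDIM/PLMS/UniPC schedule would otherwise step outside the valid 0..999 timestep range. A standalone sketch of that check (the rationale is inferred from the arithmetic, not stated in the source):

import math

def adjust_steps(num_steps):
    # when 999 / (1000 // num_steps) is an integer, the uniform schedule would
    # overshoot the last training timestep, so one extra step is requested
    valid_step = 999 / (1000 // num_steps)
    if valid_step == math.floor(valid_step):
        return int(valid_step) + 1
    return num_steps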
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index db71a549..5613b8c1 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,17 +1,17 @@
-from collections import deque
import torch
import inspect
import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
+from modules import sd_samplers_common, sd_samplers_extra
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
-from modules.processing import StableDiffusionProcessing
-from modules.shared import opts, state
+from modules.shared import opts
import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
@@ -28,10 +28,6 @@ samplers_k_diffusion = [
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
- ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
- ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
- ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
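
The entries moved to the top of samplers_k_diffusion are (label, funcname, aliases, options) tuples; the options dict drives per-sampler behavior such as the Karras schedule and Brownian noise. A sketch of reading one entry, assuming the tuple layout shown above:

import k_diffusion.sampling

label, funcname, aliases, options = ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'})
func = getattr(k_diffusion.sampling, funcname)  # resolved the same way KDiffusionSampler does
uses_karras = options.get('scheduler') == 'karras'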
@@ -57,317 +53,17 @@ k_diffusion_scheduler = {
}
-def catenate_conds(conds):
- if not isinstance(conds[0], dict):
- return torch.cat(conds)
-
- return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
-
-
-def subscript_cond(cond, a, b):
- if not isinstance(cond, dict):
- return cond[a:b]
-
- return {key: vec[a:b] for key, vec in cond.items()}
-
-
-def pad_cond(tensor, repeats, empty):
- if not isinstance(tensor, dict):
- return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
-
- tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
- return tensor
-
-
-class CFGDenoiser(torch.nn.Module):
- """
- Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
- that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
- instead of one. Originally, the second prompt is just an empty string, but we use non-empty
- negative prompt.
- """
-
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.step = 0
- self.image_cfg_scale = None
- self.padded_cond_uncond = False
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
- return denoised
-
- def combine_denoised_for_edit_model(self, x_out, cond_scale):
- out_cond, out_img_cond, out_uncond = x_out.chunk(3)
- denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
-
- return denoised
-
- def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
- # at self.image_cfg_scale == 1.0, results produced for the edit model are the same as with normal sampling,
- # so is_edit_model is set to False to support AND composition.
- is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
- assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
-
- if shared.sd_model.model.conditioning_key == "crossattn-adm":
- image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
- else:
- image_uncond = image_cond
- if isinstance(uncond, dict):
- make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
- else:
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
-
- if not is_edit_model:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
- else:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
-
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
- tensor = denoiser_params.text_cond
- uncond = denoiser_params.text_uncond
- skip_uncond = False
-
- # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
- if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
- skip_uncond = True
- x_in = x_in[:-batch_size]
- sigma_in = sigma_in[:-batch_size]
-
- self.padded_cond_uncond = False
- if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
- empty = shared.sd_model.cond_stage_model_empty_prompt
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
- if num_repeats < 0:
- tensor = pad_cond(tensor, -num_repeats, empty)
- self.padded_cond_uncond = True
- elif num_repeats > 0:
- uncond = pad_cond(uncond, num_repeats, empty)
- self.padded_cond_uncond = True
-
- if tensor.shape[1] == uncond.shape[1] or skip_uncond:
- if is_edit_model:
- cond_in = catenate_conds([tensor, uncond, uncond])
- elif skip_uncond:
- cond_in = tensor
- else:
- cond_in = catenate_conds([tensor, uncond])
-
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
-
- if not is_edit_model:
- c_crossattn = subscript_cond(tensor, a, b)
- else:
- c_crossattn = torch.cat([tensor[a:b]], uncond)
-
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
-
- if not skip_uncond:
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
-
- denoised_image_indexes = [x[0][0] for x in conds_list]
- if skip_uncond:
- fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
- x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
-
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
- cfg_denoised_callback(denoised_params)
-
- devices.test_for_nans(x_out, "unet")
-
- if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
- elif opts.live_preview_content == "Negative prompt":
- sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
- if is_edit_model:
- denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
- elif skip_uncond:
- denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
- else:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
-
- after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
- cfg_after_cfg_callback(after_cfg_callback_params)
- denoised = after_cfg_callback_params.x
-
- self.step += 1
- return denoised
-
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
+class KDiffusionSampler(sd_samplers_common.Sampler):
+ def __init__(self, funcname, sd_model):
- return devices.randn_like(x)
+ super().__init__(funcname)
+ self.extra_params = sampler_extra_params.get(funcname, [])
+ self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.funcname = funcname
- self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
- self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.config = None # set by the function calling the constructor
- self.last_latent = None
- self.s_min_uncond = None
-
- # NOTE: These are also defined in the StableDiffusionProcessing class.
- # They should have been here to begin with but we're going to
- # leave that class __init__ signature alone.
- self.s_churn = 0.0
- self.s_tmin = 0.0
- self.s_tmax = float('inf')
- self.s_noise = 1.0
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if opts.live_preview_content == "Combined":
- sd_samplers_common.store_latent(latent)
- self.last_latent = latent
-
- if self.stop_at is not None and step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- state.sampling_step = step
- shared.total_tqdm.update()
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except RecursionError:
- print(
- 'Encountered RecursionError during sampling, returning last latent. '
- 'rho >5 with a polyexponential scheduler may cause this error. '
- 'You should try to use a smaller rho value instead.'
- )
- return self.last_latent
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p: StableDiffusionProcessing):
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.step = 0
- self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
- self.eta = p.eta if p.eta is not None else opts.eta_ancestral
- self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
-
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
-
- if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
- p.extra_generation_params["Eta"] = self.eta
-
- extra_params_kwargs['eta'] = self.eta
-
- if len(self.extra_params) > 0:
- s_churn = getattr(opts, 's_churn', p.s_churn)
- s_tmin = getattr(opts, 's_tmin', p.s_tmin)
- s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
- s_noise = getattr(opts, 's_noise', p.s_noise)
-
- if s_churn != self.s_churn:
- extra_params_kwargs['s_churn'] = s_churn
- p.s_churn = s_churn
- p.extra_generation_params['Sigma churn'] = s_churn
- if s_tmin != self.s_tmin:
- extra_params_kwargs['s_tmin'] = s_tmin
- p.s_tmin = s_tmin
- p.extra_generation_params['Sigma tmin'] = s_tmin
- if s_tmax != self.s_tmax:
- extra_params_kwargs['s_tmax'] = s_tmax
- p.s_tmax = s_tmax
- p.extra_generation_params['Sigma tmax'] = s_tmax
- if s_noise != self.s_noise:
- extra_params_kwargs['s_noise'] = s_noise
- p.s_noise = s_noise
- p.extra_generation_params['Sigma noise'] = s_noise
-
- return extra_params_kwargs
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap, self)
def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -419,22 +115,12 @@ class KDiffusionSampler:
return sigmas
- def create_noise_sampler(self, x, sigmas, p):
- """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
- if shared.opts.no_dpmpp_sde_batch_determinism:
- return None
-
- from k_diffusion.sampling import BrownianTreeNoiseSampler
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
- current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
- return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
-
sigma_sched = sigmas[steps - t_enc - 1:]
+
xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p)
@@ -483,12 +169,14 @@ class KDiffusionSampler:
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = steps
+
if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in parameters:
- extra_params_kwargs['n'] = steps
- else:
+
+ if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigmas
if self.config.options.get('brownian_noise', False):
@@ -509,3 +197,4 @@ class KDiffusionSampler:
return samples
+
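
The reordered block above inspects each k-diffusion sampler's signature and passes only the keyword arguments it declares: 'n' for samplers that build their own schedule, 'sigma_min'/'sigma_max' for those that need the range, and 'sigmas' for those that accept a precomputed schedule. A minimal sketch of that dispatch pattern:

import inspect

def build_sampler_kwargs(func, steps, sigmas, sigma_min, sigma_max):
    # forward only the arguments the sampler function actually accepts
    params = inspect.signature(func).parameters
    kwargs = {}
    if 'n' in params:
        kwargs['n'] = steps
    if 'sigma_min' in params:
        kwargs['sigma_min'] = sigma_min
        kwargs['sigma_max'] = sigma_max
    if 'sigmas' in params:
        kwargs['sigmas'] = sigmas
    return kwargs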
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
new file mode 100644
index 00000000..f61799a8
--- /dev/null
+++ b/modules/sd_samplers_timesteps.py
@@ -0,0 +1,151 @@
+import torch
+import inspect
+import sys
+from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+
+from modules.shared import opts
+import modules.shared as shared
+
+samplers_timesteps = [
+ ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+ ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
+ ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
+]
+
+
+samplers_data_timesteps = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_timesteps
+]
+
+
+class CompVisTimestepsDenoiser(torch.nn.Module):
+ def __init__(self, model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.inner_model = model
+
+ def forward(self, input, timesteps, **kwargs):
+ return self.inner_model.apply_model(input, timesteps, **kwargs)
+
+
+class CompVisTimestepsVDenoiser(torch.nn.Module):
+ def __init__(self, model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.inner_model = model
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+
+ def forward(self, input, timesteps, **kwargs):
+ model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
+ e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
+ return e_t
+
+
+class CFGDenoiserTimesteps(CFGDenoiser):
+
+ def __init__(self, model, sampler):
+ super().__init__(model, sampler)
+
+ self.alphas = model.inner_model.alphas_cumprod
+
+ def get_pred_x0(self, x_in, x_out, sigma):
+ ts = int(sigma.item())
+
+ s_in = x_in.new_ones([x_in.shape[0]])
+ a_t = self.alphas[ts].item() * s_in
+ sqrt_one_minus_at = (1 - a_t).sqrt()
+
+ pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
+
+ return pred_x0
+
+
+class CompVisSampler(sd_samplers_common.Sampler):
+ def __init__(self, funcname, sd_model):
+ super().__init__(funcname)
+
+ self.eta_option_field = 'eta_ddim'
+ self.eta_infotext_field = 'Eta DDIM'
+
+ denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
+ self.model_wrap = denoiser(sd_model)
+ self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
+
+ def get_timesteps(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
+
+ timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
+
+ return timesteps
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+ timesteps = self.get_timesteps(p, steps)
+ timesteps_sched = timesteps[:t_enc]
+
+ alphas_cumprod = shared.sd_model.alphas_cumprod
+ sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
+ sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
+
+ xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'timesteps' in parameters:
+ extra_params_kwargs['timesteps'] = timesteps_sched
+ if 'is_img2img' in parameters:
+ extra_params_kwargs['is_img2img'] = True
+
+ self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
+ extra_args = {
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps = steps or p.steps
+ timesteps = self.get_timesteps(p, steps)
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'timesteps' in parameters:
+ extra_params_kwargs['timesteps'] = timesteps
+
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ return samples
+
+
+sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]
+VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions
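
get_timesteps above replaces the old per-sampler discretization with a uniform schedule over the 1000 training timesteps, clipped to [0, 999]. A quick sketch of what it produces (device placement and the discard-penultimate-sigma adjustment omitted):

import torch

def uniform_timesteps(steps):
    # same arithmetic as CompVisSampler.get_timesteps
    return torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps))) + 1, 0, 999)

print(uniform_timesteps(10))  # tensor([  1, 101, 201, 301, 401, 501, 601, 701, 801, 901])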
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
new file mode 100644
index 00000000..48d7e649
--- /dev/null
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -0,0 +1,135 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+import numpy as np
+
+from modules import shared
+from modules.models.diffusion.uni_pc import uni_pc
+
+
+@torch.no_grad()
+def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+
+ e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+
+ a_t = alphas[index].item() * s_in
+ a_prev = alphas_prev[index].item() * s_in
+ sigma_t = sigmas[index].item() * s_in
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+ noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+@torch.no_grad()
+def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+ alphas = alphas_cumprod[timesteps]
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ old_eps = []
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = alphas[index].item() * s_in
+ a_prev = alphas_prev[index].item() * s_in
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev).sqrt() * e_t
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev, pred_x0
+
+ for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+ index = len(timesteps) - 1 - i
+ ts = timesteps[index].item() * s_in
+ t_next = timesteps[max(index - 1, 0)].item() * s_in
+
+ e_t = model(x, ts, **extra_args)
+
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = model(x_prev, t_next, **extra_args)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ else:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+
+ x = x_prev
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+ return x
+
+
+class UniPCCFG(uni_pc.UniPC):
+ def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
+ super().__init__(None, *args, **kwargs)
+
+ def after_update(x, model_x):
+ callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
+ self.index += 1
+
+ self.cfg_model = cfg_model
+ self.extra_args = extra_args
+ self.callback = callback
+ self.index = 0
+ self.after_update = after_update
+
+ def get_model_input_time(self, t_continuous):
+ return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
+
+ def model(self, x, t):
+ t_input = self.get_model_input_time(t)
+
+ res = self.cfg_model(x, t_input, **self.extra_args)
+
+ return res
+
+
+def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
+ alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+
+ ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
+ unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
+ x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x
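
The ddim function above is the plain DDIM update: predict x0 from the eps estimate, then step to the previous alpha, adding eta-scaled noise when eta > 0. A sketch of a single step using the same variable names (tensors assumed broadcast-compatible):

import torch

def ddim_step(x, e_t, a_t, a_prev, sigma_t):
    pred_x0 = (x - (1 - a_t).sqrt() * e_t) / a_t.sqrt()   # current prediction for x_0
    dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t    # direction pointing to x_t
    noise = sigma_t * torch.randn_like(x)                 # stochastic part, zero when eta == 0
    return a_prev.sqrt() * pred_x0 + dir_xt + noise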
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 38bcb840..5ac1ac31 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -2,7 +2,8 @@ import os
import collections
from dataclasses import dataclass
-from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack
+
import glob
from copy import deepcopy
@@ -231,8 +232,6 @@ unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified):
- from modules import lowvram, devices, sd_hijack
-
if not sd_model:
sd_model = shared.sd_model
diff --git a/modules/shared.py b/modules/shared.py
index 97f1eab5..d9d01484 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,843 +1,52 @@
-import datetime
-import json
-import os
-import re
import sys
-import threading
-import time
-import logging
import gradio as gr
-import torch
-import tqdm
-import launch
-import modules.interrogate
-import modules.memmon
-import modules.styles
-import modules.devices as devices
-from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
+from modules import shared_cmd_options, shared_gradio_themes, options, shared_items
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
-from typing import Optional
+from modules import util
-log = logging.getLogger(__name__)
-
-demo = None
-
-parser = cmd_args.parser
-
-script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
-script_loading.preload_extensions(extensions_builtin_dir, parser)
-
-if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
- cmd_opts = parser.parse_args()
-else:
- cmd_opts, _ = parser.parse_known_args()
-
-
-restricted_opts = {
- "samples_filename_pattern",
- "directories_filename_pattern",
- "outdir_samples",
- "outdir_txt2img_samples",
- "outdir_img2img_samples",
- "outdir_extras_samples",
- "outdir_grids",
- "outdir_txt2img_grids",
- "outdir_save",
- "outdir_init_images"
-}
-
-# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
-gradio_hf_hub_themes = [
- "gradio/base",
- "gradio/glass",
- "gradio/monochrome",
- "gradio/seafoam",
- "gradio/soft",
- "gradio/dracula_test",
- "abidlabs/dracula_test",
- "abidlabs/Lime",
- "abidlabs/pakistan",
- "Ama434/neutral-barlow",
- "dawood/microsoft_windows",
- "finlaymacklon/smooth_slate",
- "Franklisi/darkmode",
- "freddyaboulton/dracula_revamped",
- "freddyaboulton/test-blue",
- "gstaff/xkcd",
- "Insuz/Mocha",
- "Insuz/SimpleIndigo",
- "JohnSmith9982/small_and_pretty",
- "nota-ai/theme",
- "nuttea/Softblue",
- "ParityError/Anime",
- "reilnuud/polite",
- "remilia/Ghostly",
- "rottenlittlecreature/Moon_Goblin",
- "step-3-profit/Midnight-Deep",
- "Taithrah/Minimal",
- "ysharma/huggingface",
- "ysharma/steampunk"
-]
-
-
-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
-
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
- (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
-
-devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
-devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
-
-device = devices.device
-weight_load_location = None if cmd_opts.lowram else "cpu"
+cmd_opts = shared_cmd_options.cmd_opts
+parser = shared_cmd_options.parser
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
-xformers_available = False
-config_filename = cmd_opts.ui_settings_file
-
-os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
-hypernetworks = {}
-loaded_hypernetworks = []
-
-
-def reload_hypernetworks():
- from modules.hypernetworks import hypernetwork
- global hypernetworks
-
- hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
-
-
-class State:
- skipped = False
- interrupted = False
- job = ""
- job_no = 0
- job_count = 0
- processing_has_refined_job_count = False
- job_timestamp = '0'
- sampling_step = 0
- sampling_steps = 0
- current_latent = None
- current_image = None
- current_image_sampling_step = 0
- id_live_preview = 0
- textinfo = None
- time_start = None
- server_start = None
- _server_command_signal = threading.Event()
- _server_command: Optional[str] = None
-
- @property
- def need_restart(self) -> bool:
- # Compatibility getter for need_restart.
- return self.server_command == "restart"
-
- @need_restart.setter
- def need_restart(self, value: bool) -> None:
- # Compatibility setter for need_restart.
- if value:
- self.server_command = "restart"
-
- @property
- def server_command(self):
- return self._server_command
-
- @server_command.setter
- def server_command(self, value: Optional[str]) -> None:
- """
- Set the server command to `value` and signal that it's been set.
- """
- self._server_command = value
- self._server_command_signal.set()
-
- def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
- """
- Wait for server command to get set; return and clear the value and signal.
- """
- if self._server_command_signal.wait(timeout):
- self._server_command_signal.clear()
- req = self._server_command
- self._server_command = None
- return req
- return None
-
- def request_restart(self) -> None:
- self.interrupt()
- self.server_command = "restart"
- log.info("Received restart request")
-
- def skip(self):
- self.skipped = True
- log.info("Received skip request")
-
- def interrupt(self):
- self.interrupted = True
- log.info("Received interrupt request")
-
- def nextjob(self):
- if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
- self.do_set_current_image()
-
- self.job_no += 1
- self.sampling_step = 0
- self.current_image_sampling_step = 0
-
- def dict(self):
- obj = {
- "skipped": self.skipped,
- "interrupted": self.interrupted,
- "job": self.job,
- "job_count": self.job_count,
- "job_timestamp": self.job_timestamp,
- "job_no": self.job_no,
- "sampling_step": self.sampling_step,
- "sampling_steps": self.sampling_steps,
- }
-
- return obj
-
- def begin(self, job: str = "(unknown)"):
- self.sampling_step = 0
- self.job_count = -1
- self.processing_has_refined_job_count = False
- self.job_no = 0
- self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
- self.current_latent = None
- self.current_image = None
- self.current_image_sampling_step = 0
- self.id_live_preview = 0
- self.skipped = False
- self.interrupted = False
- self.textinfo = None
- self.time_start = time.time()
- self.job = job
- devices.torch_gc()
- log.info("Starting job %s", job)
-
- def end(self):
- duration = time.time() - self.time_start
- log.info("Ending job %s (%.2f seconds)", self.job, duration)
- self.job = ""
- self.job_count = 0
-
- devices.torch_gc()
-
- def set_current_image(self):
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
- if not parallel_processing_allowed:
- return
-
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
- self.do_set_current_image()
-
- def do_set_current_image(self):
- if self.current_latent is None:
- return
-
- import modules.sd_samplers
-
- try:
- if opts.show_progress_grid:
- self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
- else:
- self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
-
- self.current_image_sampling_step = self.sampling_step
-
- except Exception:
- # when switching models during generation, VAE would be on CPU, so creating an image will fail.
- # we silently ignore this error
- errors.record_exception()
-
- def assign_current_image(self, image):
- self.current_image = image
- self.id_live_preview += 1
-
-
-state = State()
-state.server_start = time.time()
-
styles_filename = cmd_opts.styles_file
-prompt_styles = modules.styles.StyleDatabase(styles_filename)
-
-interrogator = modules.interrogate.InterrogateModels("interrogate")
-
-face_restorers = []
-
-
-class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
- self.default = default
- self.label = label
- self.component = component
- self.component_args = component_args
- self.onchange = onchange
- self.section = section
- self.refresh = refresh
- self.do_not_save = False
-
- self.comment_before = comment_before
- """HTML text that will be added after label in UI"""
-
- self.comment_after = comment_after
- """HTML text that will be added before label in UI"""
-
- def link(self, label, url):
- self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
- return self
-
- def js(self, label, js_func):
- self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
- return self
-
- def info(self, info):
- self.comment_after += f"<span class='info'>({info})</span>"
- return self
-
- def html(self, html):
- self.comment_after += html
- return self
-
- def needs_restart(self):
- self.comment_after += " <span class='info'>(requires restart)</span>"
- return self
-
- def needs_reload_ui(self):
- self.comment_after += " <span class='info'>(requires Reload UI)</span>"
- return self
-
-
-class OptionHTML(OptionInfo):
- def __init__(self, text):
- super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
-
- self.do_not_save = True
-
-
-def options_section(section_identifier, options_dict):
- for v in options_dict.values():
- v.section = section_identifier
-
- return options_dict
-
-
-def list_checkpoint_tiles():
- import modules.sd_models
- return modules.sd_models.checkpoint_tiles()
-
-
-def refresh_checkpoints():
- import modules.sd_models
- return modules.sd_models.list_models()
-
-
-def list_samplers():
- import modules.sd_samplers
- return modules.sd_samplers.all_samplers
-
-
+config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
-tab_names = []
-
-options_templates = {}
-
-options_templates.update(options_section(('saving-images', "Saving images/grids"), {
- "samples_save": OptionInfo(True, "Always save all generated images"),
- "samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
- "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
-
- "grid_save": OptionInfo(True, "Always save all generated image grids"),
- "grid_format": OptionInfo('png', 'File format for grids'),
- "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
- "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
- "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
- "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
- "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
- "font": OptionInfo("", "Font for image grids that have text"),
- "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
- "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
- "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
-
- "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
- "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
- "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
- "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
- "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
- "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
- "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
- "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
- "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
- "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
- "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
-
- "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
- "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
- "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
- "save_init_img": OptionInfo(False, "Save init images when using img2img"),
-
- "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
- "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
-
- "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
-}))
-
-options_templates.update(options_section(('saving-paths', "Paths for saving"), {
- "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
- "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
- "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
- "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
- "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
- "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
- "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
- "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
- "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
-}))
-
-options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
- "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
- "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
- "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
- "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
-}))
-
-options_templates.update(options_section(('upscaling', "Upscaling"), {
- "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
- "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
- "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
- "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
-}))
-
-options_templates.update(options_section(('face-restoration', "Face restoration"), {
- "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
- "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
- "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
-}))
-
-options_templates.update(options_section(('system', "System"), {
- "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
- "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
- "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
- "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
- "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
- "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
- "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
- "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
- "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
- "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
-}))
-
-options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
- "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
- "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
- "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
- "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
- "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
- "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
- "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
- "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
- "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
- "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
-}))
-
-options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
- "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
- "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
- "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
- "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
- "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
- "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
- "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
- "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
- "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
- "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
-}))
-
-options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
- "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
- "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
- "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
- "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
-}))
-options_templates.update(options_section(('vae', "VAE"), {
- "sd_vae_explanation": OptionHTML("""
-<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
-image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
-(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
-For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
-"""),
- "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
- "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
- "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
- "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
- "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
-}))
-
-options_templates.update(options_section(('img2img', "img2img"), {
- "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
- "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
- "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
- "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
- "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
- "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
- "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
- "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
- "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
-}))
-
-options_templates.update(options_section(('optimizations', "Optimizations"), {
- "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
- "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
- "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
- "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
- "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
- "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
- "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
-}))
-
-options_templates.update(options_section(('compatibility', "Compatibility"), {
- "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
- "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
- "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
- "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
- "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
- "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
-}))
-
-options_templates.update(options_section(('interrogate', "Interrogate"), {
- "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
- "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
- "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
- "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
- "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
- "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
- "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
- "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
- "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
- "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
-}))
-
-options_templates.update(options_section(('extra_networks', "Extra Networks"), {
- "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
- "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
- "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
- "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
- "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
- "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
- "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
- "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
- "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
- "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
- "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
-}))
-
-options_templates.update(options_section(('ui', "User interface"), {
- "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
- "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
- "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
- "return_grid": OptionInfo(True, "Show grid in results for web"),
- "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
- "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
- "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
- "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
- "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
- "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
- "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
- "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
- "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
- "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
- "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
- "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
- "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
- "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
- "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
- "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
- "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
- "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
-}))
-
-
-options_templates.update(options_section(('infotext', "Infotext"), {
- "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
- "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
- "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
- "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
- "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
-<li>Ignore: keep prompt and styles dropdown as it is.</li>
-<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
-<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
-<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
-</ul>"""),
-
-}))
-
-options_templates.update(options_section(('ui', "Live previews"), {
- "show_progressbar": OptionInfo(True, "Show progressbar"),
- "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
- "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
- "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
- "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
- "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
- "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
-}))
-
-options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(),
- "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
- "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
- "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
- 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
- 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
- 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
- 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
- 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
- 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
- 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
- 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
- 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
- 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
- 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
- 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
- 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
- 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
-}))
-
-options_templates.update(options_section(('postprocessing', "Postprocessing"), {
- 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
- 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
- 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-}))
-
-options_templates.update(options_section((None, "Hidden options"), {
- "disabled_extensions": OptionInfo([], "Disable these extensions"),
- "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
- "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
- "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
-}))
-
-
-options_templates.update()
-
-
-class Options:
- data = None
- data_labels = options_templates
- typemap = {int: float}
-
- def __init__(self):
- self.data = {k: v.default for k, v in self.data_labels.items()}
-
- def __setattr__(self, key, value):
- if self.data is not None:
- if key in self.data or key in self.data_labels:
- assert not cmd_opts.freeze_settings, "changing settings is disabled"
-
- info = opts.data_labels.get(key, None)
- if info.do_not_save:
- return
-
- comp_args = info.component_args if info else None
- if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
- raise RuntimeError(f"not possible to set {key} because it is restricted")
-
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- raise RuntimeError(f"not possible to set {key} because it is restricted")
-
- self.data[key] = value
- return
-
- return super(Options, self).__setattr__(key, value)
-
- def __getattr__(self, item):
- if self.data is not None:
- if item in self.data:
- return self.data[item]
-
- if item in self.data_labels:
- return self.data_labels[item].default
-
- return super(Options, self).__getattribute__(item)
-
- def set(self, key, value):
- """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
-
- oldval = self.data.get(key, None)
- if oldval == value:
- return False
-
- if self.data_labels[key].do_not_save:
- return False
-
- try:
- setattr(self, key, value)
- except RuntimeError:
- return False
-
- if self.data_labels[key].onchange is not None:
- try:
- self.data_labels[key].onchange()
- except Exception as e:
- errors.display(e, f"changing setting {key} to {value}")
- setattr(self, key, oldval)
- return False
-
- return True
-
- def get_default(self, key):
- """returns the default value for the key"""
-
- data_label = self.data_labels.get(key)
- if data_label is None:
- return None
-
- return data_label.default
-
- def save(self, filename):
- assert not cmd_opts.freeze_settings, "saving settings is disabled"
-
- with open(filename, "w", encoding="utf8") as file:
- json.dump(self.data, file, indent=4)
-
- def same_type(self, x, y):
- if x is None or y is None:
- return True
-
- type_x = self.typemap.get(type(x), type(x))
- type_y = self.typemap.get(type(y), type(y))
-
- return type_x == type_y
-
- def load(self, filename):
- with open(filename, "r", encoding="utf8") as file:
- self.data = json.load(file)
-
- # 1.6.0 VAE defaults
- if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
- self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
-
- # 1.1.1 quicksettings list migration
- if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
- self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
-
- # 1.4.0 ui_reorder
- if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
- self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
-
- bad_settings = 0
- for k, v in self.data.items():
- info = self.data_labels.get(k, None)
- if info is not None and not self.same_type(info.default, v):
- print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
- bad_settings += 1
-
- if bad_settings > 0:
- print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
-
- def onchange(self, key, func, call=True):
- item = self.data_labels.get(key)
- item.onchange = func
-
- if call:
- func()
-
- def dumpjson(self):
- d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
- d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
- d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
- return json.dumps(d)
-
- def add_option(self, key, info):
- self.data_labels[key] = info
-
- def reorder(self):
- """reorder settings so that all items related to section always go together"""
-
- section_ids = {}
- settings_items = self.data_labels.items()
- for _, item in settings_items:
- if item.section not in section_ids:
- section_ids[item.section] = len(section_ids)
-
- self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
-
- def cast_value(self, key, value):
- """casts an arbitrary to the same type as this setting's value with key
- Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
- """
-
- if value is None:
- return None
-
- default_value = self.data_labels[key].default
- if default_value is None:
- default_value = getattr(self, key, None)
- if default_value is None:
- return None
-
- expected_type = type(default_value)
- if expected_type == bool and value == "False":
- value = False
- else:
- value = expected_type(value)
-
- return value
-opts = Options()
-if os.path.exists(config_filename):
- opts.load(config_filename)
-class Shared(sys.modules[__name__].__class__):
- """
- this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
- at program startup.
- """
- sd_model_val = None
- @property
- def sd_model(self):
- import modules.sd_models
- return modules.sd_models.model_data.get_sd_model()
- @sd_model.setter
- def sd_model(self, value):
- import modules.sd_models
- modules.sd_models.model_data.set_sd_model(value)
-sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
-sys.modules[__name__].__class__ = Shared
+demo = None
+device = None
+weight_load_location = None
+xformers_available = False
+hypernetworks = {}
+loaded_hypernetworks = []
+state = None
+prompt_styles = None
+interrogator = None
+face_restorers = []
+options_templates = None
+opts = None
+restricted_opts = None
+sd_model: LatentDiffusion = None
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
+tab_names = []
+
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
@@ -856,121 +65,24 @@ progress_print_out = sys.stdout
gradio_theme = gr.themes.Base()
-def reload_gradio_theme(theme_name=None):
- global gradio_theme
- if not theme_name:
- theme_name = opts.gradio_theme
-
- default_theme_args = dict(
- font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
- font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
- )
-
- if theme_name == "Default":
- gradio_theme = gr.themes.Default(**default_theme_args)
- else:
- try:
- theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
- theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
- if opts.gradio_themes_cache and os.path.exists(theme_cache_path):
- gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
- else:
- os.makedirs(theme_cache_dir, exist_ok=True)
- gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
- gradio_theme.dump(theme_cache_path)
- except Exception as e:
- errors.display(e, "changing gradio theme")
- gradio_theme = gr.themes.Default(**default_theme_args)
-
-
-class TotalTQDM:
- def __init__(self):
- self._tqdm = None
-
- def reset(self):
- self._tqdm = tqdm.tqdm(
- desc="Total progress",
- total=state.job_count * state.sampling_steps,
- position=1,
- file=progress_print_out
- )
-
- def update(self):
- if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
- return
- if self._tqdm is None:
- self.reset()
- self._tqdm.update()
-
- def updateTotal(self, new_total):
- if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
- return
- if self._tqdm is None:
- self.reset()
- self._tqdm.total = new_total
-
- def clear(self):
- if self._tqdm is not None:
- self._tqdm.refresh()
- self._tqdm.close()
- self._tqdm = None
-
-
-total_tqdm = TotalTQDM()
-
-mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
-mem_mon.start()
-
-
-def natural_sort_key(s, regex=re.compile('([0-9]+)')):
- return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
-
-
-def listfiles(dirname):
- filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
- return [file for file in filenames if os.path.isfile(file)]
-
-
-def html_path(filename):
- return os.path.join(script_path, "html", filename)
-
-
-def html(filename):
- path = html_path(filename)
-
- if os.path.exists(path):
- with open(path, encoding="utf8") as file:
- return file.read()
-
- return ""
-
-
-def walk_files(path, allowed_extensions=None):
- if not os.path.exists(path):
- return
-
- if allowed_extensions is not None:
- allowed_extensions = set(allowed_extensions)
-
- items = list(os.walk(path, followlinks=True))
- items = sorted(items, key=lambda x: natural_sort_key(x[0]))
-
- for root, _, files in items:
- for filename in sorted(files, key=natural_sort_key):
- if allowed_extensions is not None:
- _, ext = os.path.splitext(filename)
- if ext not in allowed_extensions:
- continue
-
- if not opts.list_hidden_files and ("/." in root or "\\." in root):
- continue
- yield os.path.join(root, filename)
-def ldm_print(*args, **kwargs):
- if opts.hide_ldm_prints:
- return
- print(*args, **kwargs)
+total_tqdm = None
+mem_mon = None
+options_section = options.options_section
+OptionInfo = options.OptionInfo
+OptionHTML = options.OptionHTML
+natural_sort_key = util.natural_sort_key
+listfiles = util.listfiles
+html_path = util.html_path
+html = util.html
+walk_files = util.walk_files
+ldm_print = util.ldm_print
+reload_gradio_theme = shared_gradio_themes.reload_gradio_theme
+list_checkpoint_tiles = shared_items.list_checkpoint_tiles
+refresh_checkpoints = shared_items.refresh_checkpoints
+list_samplers = shared_items.list_samplers
+reload_hypernetworks = shared_items.reload_hypernetworks
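Taken together, the shared.py hunks above turn the module into a thin facade: state and helpers move out, and the old names survive as aliases so third-party code importing them from modules.shared keeps working. A hedged sketch of that compatibility surface in use (the directory path is illustrative):

    from modules import shared

    # These names still resolve through shared.py but now point at the split modules:
    files = shared.listfiles("models/Stable-diffusion")   # backed by modules.util
    shared.reload_gradio_theme("Default")                 # backed by modules.shared_gradio_themes
    for sampler in shared.list_samplers():                # backed by modules.shared_items
        print(sampler.name)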
diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py
new file mode 100644
index 00000000..af24938b
--- /dev/null
+++ b/modules/shared_cmd_options.py
@@ -0,0 +1,18 @@
+import os
+
+import launch
+from modules import cmd_args, script_loading
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+
+parser = cmd_args.parser
+
+script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
+script_loading.preload_extensions(extensions_builtin_dir, parser)
+
+if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
+ cmd_opts = parser.parse_args()
+else:
+ cmd_opts, _ = parser.parse_known_args()
+
+
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
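Because this module parses sys.argv as a side effect of its first import, any process that imports webui code with a foreign command line (a test runner, say) would otherwise die in parse_args(). The IGNORE_CMD_ARGS_ERRORS variable downgrades that to parse_known_args(); a minimal sketch, assuming it is set before the first import:

    import os

    # Must happen before modules.shared_cmd_options is first imported:
    os.environ["IGNORE_CMD_ARGS_ERRORS"] = "1"

    from modules.shared_cmd_options import cmd_opts  # unknown argv entries are now ignored
    print(cmd_opts.disable_extension_access)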
diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py
new file mode 100644
index 00000000..ad1f2212
--- /dev/null
+++ b/modules/shared_gradio_themes.py
@@ -0,0 +1,66 @@
+import os
+
+import gradio as gr
+
+from modules import errors, shared
+from modules.paths_internal import script_path
+
+
+# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
+gradio_hf_hub_themes = [
+ "gradio/base",
+ "gradio/glass",
+ "gradio/monochrome",
+ "gradio/seafoam",
+ "gradio/soft",
+ "gradio/dracula_test",
+ "abidlabs/dracula_test",
+ "abidlabs/Lime",
+ "abidlabs/pakistan",
+ "Ama434/neutral-barlow",
+ "dawood/microsoft_windows",
+ "finlaymacklon/smooth_slate",
+ "Franklisi/darkmode",
+ "freddyaboulton/dracula_revamped",
+ "freddyaboulton/test-blue",
+ "gstaff/xkcd",
+ "Insuz/Mocha",
+ "Insuz/SimpleIndigo",
+ "JohnSmith9982/small_and_pretty",
+ "nota-ai/theme",
+ "nuttea/Softblue",
+ "ParityError/Anime",
+ "reilnuud/polite",
+ "remilia/Ghostly",
+ "rottenlittlecreature/Moon_Goblin",
+ "step-3-profit/Midnight-Deep",
+ "Taithrah/Minimal",
+ "ysharma/huggingface",
+ "ysharma/steampunk"
+]
+
+
+def reload_gradio_theme(theme_name=None):
+ if not theme_name:
+ theme_name = shared.opts.gradio_theme
+
+ default_theme_args = dict(
+ font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
+ font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
+ )
+
+ if theme_name == "Default":
+ shared.gradio_theme = gr.themes.Default(**default_theme_args)
+ else:
+ try:
+ theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
+ theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
+ if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
+ shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
+ else:
+ os.makedirs(theme_cache_dir, exist_ok=True)
+ shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
+ shared.gradio_theme.dump(theme_cache_path)
+ except Exception as e:
+ errors.display(e, "changing gradio theme")
+ shared.gradio_theme = gr.themes.Default(**default_theme_args)
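reload_gradio_theme keeps its behavior but now writes its result into shared.gradio_theme, and hub themes go through a small on-disk cache so repeated reloads avoid the network. Illustrative usage, with a theme name from the list above:

    from modules import shared_gradio_themes

    # First call: fetched from the hub and dumped to tmp/gradio_themes/gradio_soft.json.
    shared_gradio_themes.reload_gradio_theme("gradio/soft")

    # Later calls: loaded from that cached json, as long as
    # shared.opts.gradio_themes_cache is enabled.
    shared_gradio_themes.reload_gradio_theme("gradio/soft")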
diff --git a/modules/shared_init.py b/modules/shared_init.py
new file mode 100644
index 00000000..d3fb687e
--- /dev/null
+++ b/modules/shared_init.py
@@ -0,0 +1,49 @@
+import os
+
+import torch
+
+from modules import shared
+from modules.shared import cmd_opts
+
+
+def initialize():
+ """Initializes fields inside the shared module in a controlled manner.
+
+ Should be called early because some other modules you can import might need these fields to already be set.
+ """
+
+ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
+
+ from modules import options, shared_options
+ shared.options_templates = shared_options.options_templates
+ shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
+ shared.restricted_opts = shared_options.restricted_opts
+ if os.path.exists(shared.config_filename):
+ shared.opts.load(shared.config_filename)
+
+ from modules import devices
+ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
+ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
+
+ devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+
+ shared.device = devices.device
+ shared.weight_load_location = None if cmd_opts.lowram else "cpu"
+
+ from modules import shared_state
+ shared.state = shared_state.State()
+
+ from modules import styles
+ shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
+
+ from modules import interrogate
+ shared.interrogator = interrogate.InterrogateModels("interrogate")
+
+ from modules import shared_total_tqdm
+ shared.total_tqdm = shared_total_tqdm.TotalTQDM()
+
+ from modules import memmon
+ shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
+ shared.mem_mon.start()
+
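initialize() is where the None placeholders that shared.py now declares become real objects, so it has to run before anything that reads shared.opts or shared.state at import time. A hedged sketch of that ordering contract (the real caller is the app's startup code, left abstract here):

    from modules import shared_init

    shared_init.initialize()        # fills in shared.opts, shared.state, devices.dtype, ...

    from modules import shared
    assert shared.opts is not None  # settings are loaded (if the config file exists)
    assert shared.state is not None # safe for other modules to use from here on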
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 89792e88..e4ec40a8 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -1,3 +1,6 @@
+import sys
+
+from modules.shared_cmd_options import cmd_opts
def realesrgan_models_names():
@@ -41,6 +44,28 @@ def refresh_unet_list():
modules.sd_unet.list_unets()
+def list_checkpoint_tiles():
+ import modules.sd_models
+ return modules.sd_models.checkpoint_tiles()
+
+
+def refresh_checkpoints():
+ import modules.sd_models
+ return modules.sd_models.list_models()
+
+
+def list_samplers():
+ import modules.sd_samplers
+ return modules.sd_samplers.all_samplers
+
+
+def reload_hypernetworks():
+ from modules.hypernetworks import hypernetwork
+ from modules import shared
+
+ shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+
+
ui_reorder_categories_builtin_items = [
"inpaint",
"sampler",
@@ -67,3 +92,27 @@ def ui_reorder_categories():
yield from sections
yield "scripts"
+
+
+class Shared(sys.modules[__name__].__class__):
+ """
+ this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
+ at program startup.
+ """
+
+ sd_model_val = None
+
+ @property
+ def sd_model(self):
+ import modules.sd_models
+
+ return modules.sd_models.model_data.get_sd_model()
+
+ @sd_model.setter
+ def sd_model(self, value):
+ import modules.sd_models
+
+ modules.sd_models.model_data.set_sd_model(value)
+
+
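+# Swapping the class of the already-imported 'modules.shared' module is what makes the
+# module-level property possible: from here on, shared.sd_model goes through the property
+# above and loads the model on first access instead of at program startup.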
+sys.modules['modules.shared'].__class__ = Shared
diff --git a/modules/shared_options.py b/modules/shared_options.py
new file mode 100644
index 00000000..7468bc81
--- /dev/null
+++ b/modules/shared_options.py
@@ -0,0 +1,316 @@
+import gradio as gr
+
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules.shared_cmd_options import cmd_opts
+from modules.options import options_section, OptionInfo, OptionHTML
+
+options_templates = {}
+hide_dirs = shared.hide_dirs
+
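+# output paths and filename patterns; treated as restricted so they stay read-only
+# when the UI is launched with directory settings hidden (--hide-ui-dir-config)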
+restricted_opts = {
+ "samples_filename_pattern",
+ "directories_filename_pattern",
+ "outdir_samples",
+ "outdir_txt2img_samples",
+ "outdir_img2img_samples",
+ "outdir_extras_samples",
+ "outdir_grids",
+ "outdir_txt2img_grids",
+ "outdir_save",
+ "outdir_init_images"
+}
+
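+# each entry below is an OptionInfo(default, label[, component, component_args]);
+# options_section((section_id, section_label), {...}) assigns entries to a settings tab,
+# and chained calls such as .info(), .link() and .needs_reload_ui() attach extra metadata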
+options_templates.update(options_section(('saving-images', "Saving images/grids"), {
+ "samples_save": OptionInfo(True, "Always save all generated images"),
+ "samples_format": OptionInfo('png', 'File format for images'),
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
+ "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
+
+ "grid_save": OptionInfo(True, "Always save all generated image grids"),
+ "grid_format": OptionInfo('png', 'File format for grids'),
+ "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
+ "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
+ "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
+ "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "font": OptionInfo("", "Font for image grids that have text"),
+ "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
+
+ "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
+ "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
+ "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+ "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
+ "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
+ "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
+ "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+ "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
+ "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
+ "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
+ "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
+ "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
+
+ "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
+ "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
+ "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
+ "save_init_img": OptionInfo(False, "Save init images when using img2img"),
+
+ "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
+ "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
+
+ "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
+}))
+
+options_templates.update(options_section(('saving-paths', "Paths for saving"), {
+ "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
+ "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
+ "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
+ "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
+ "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
+ "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
+ "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
+ "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
+ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
+}))
+
+options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
+ "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
+ "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
+ "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
+}))
+
+options_templates.update(options_section(('upscaling', "Upscaling"), {
+ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+ "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
+}))
+
+options_templates.update(options_section(('face-restoration', "Face restoration"), {
+ "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
+ "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
+ "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
+}))
+
+options_templates.update(options_section(('system', "System"), {
+ "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
+ "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
+ "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
+ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
+ "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
+ "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+ "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
+ "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+ "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
+}))
+
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+ "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
+ "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+ "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
+ "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
+ "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
+ "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
+}))
+
+options_templates.update(options_section(('sd', "Stable Diffusion"), {
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints),
+ "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
+ "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
+ "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+ "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
+ "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
+}))
+
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+ "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+ "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+ "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+ "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))
+
+options_templates.update(options_section(('vae', "VAE"), {
+ "sd_vae_explanation": OptionHTML("""
+<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
+image into a latent space representation and back. The latent space representation is what stable diffusion works on during sampling
+(i.e. when the progress bar is between empty and full). For txt2img, the VAE is used to create the resulting image after sampling is finished.
+For img2img, the VAE is used to process the user's input image before sampling, and to create an image after sampling.
+"""),
+ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
+ "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
+ "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
+ "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
+ "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
+}))
+
+options_templates.update(options_section(('img2img', "img2img"), {
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
+ "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+ "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
+ "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
+ "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
+ "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
+ "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
+ "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+ "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
+}))
+
+options_templates.update(options_section(('optimizations', "Optimizations"), {
+ "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
+ "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+ "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
+}))
+
+options_templates.update(options_section(('compatibility', "Compatibility"), {
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
+ "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
+ "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
+ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
+ "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
+}))
+
+options_templates.update(options_section(('interrogate', "Interrogate"), {
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
+ "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
+ "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
+ "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
+ "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
+ "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
+ "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
+ "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
+ "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
+}))
+
+options_templates.update(options_section(('extra_networks', "Extra Networks"), {
+ "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
+ "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
+ "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
+ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
+ "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
+ "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
+ "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
+ "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
+ "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
+}))
+
+options_templates.update(options_section(('ui', "User interface"), {
+ "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
+ "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
+ "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
+ "return_grid": OptionInfo(True, "Show grid in results for web"),
+ "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
+ "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
+ "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
+ "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
+ "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
+ "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
+ "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
+ "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
+ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
+ "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
+ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
+ "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
+}))
+
+
+options_templates.update(options_section(('infotext', "Infotext"), {
+ "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
+ "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
+ "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
+ "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
+<li>Ignore: keep prompt and styles dropdown as it is.</li>
+<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
+<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
+<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
+</ul>"""),
+
+}))
+
+options_templates.update(options_section(('ui', "Live previews"), {
+ "show_progressbar": OptionInfo(True, "Show progressbar"),
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
+ "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
+ "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
+}))
+
+options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
+ "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
+ "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
+ 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
+ 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
+ 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
+ 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
+ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
+ 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
+ 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
+}))
+
+options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+ 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+ 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+ 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+}))
+
+options_templates.update(options_section((None, "Hidden options"), {
+ "disabled_extensions": OptionInfo([], "Disable these extensions"),
+ "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
+ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
+}))
+
diff --git a/modules/shared_state.py b/modules/shared_state.py
new file mode 100644
index 00000000..3dc9c788
--- /dev/null
+++ b/modules/shared_state.py
@@ -0,0 +1,159 @@
+import datetime
+import logging
+import threading
+import time
+
+from modules import errors, shared, devices
+from typing import Optional
+
+log = logging.getLogger(__name__)
+
+
+class State:
+ skipped = False
+ interrupted = False
+ job = ""
+ job_no = 0
+ job_count = 0
+ processing_has_refined_job_count = False
+ job_timestamp = '0'
+ sampling_step = 0
+ sampling_steps = 0
+ current_latent = None
+ current_image = None
+ current_image_sampling_step = 0
+ id_live_preview = 0
+ textinfo = None
+ time_start = None
+ server_start = None
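+ # the Event below lets wait_for_server_command() block until another thread stores
+ # a command (e.g. "restart") via the server_command setter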
+ _server_command_signal = threading.Event()
+ _server_command: Optional[str] = None
+
+ def __init__(self):
+ self.server_start = time.time()
+
+ @property
+ def need_restart(self) -> bool:
+ # Compatibility getter for need_restart.
+ return self.server_command == "restart"
+
+ @need_restart.setter
+ def need_restart(self, value: bool) -> None:
+ # Compatibility setter for need_restart.
+ if value:
+ self.server_command = "restart"
+
+ @property
+ def server_command(self):
+ return self._server_command
+
+ @server_command.setter
+ def server_command(self, value: Optional[str]) -> None:
+ """
+ Set the server command to `value` and signal that it's been set.
+ """
+ self._server_command = value
+ self._server_command_signal.set()
+
+ def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
+ """
+ Wait for server command to get set; return and clear the value and signal.
+ """
+ if self._server_command_signal.wait(timeout):
+ self._server_command_signal.clear()
+ req = self._server_command
+ self._server_command = None
+ return req
+ return None
+
+ def request_restart(self) -> None:
+ self.interrupt()
+ self.server_command = "restart"
+ log.info("Received restart request")
+
+ def skip(self):
+ self.skipped = True
+ log.info("Received skip request")
+
+ def interrupt(self):
+ self.interrupted = True
+ log.info("Received interrupt request")
+
+ def nextjob(self):
+ if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
+ self.do_set_current_image()
+
+ self.job_no += 1
+ self.sampling_step = 0
+ self.current_image_sampling_step = 0
+
+ def dict(self):
+ obj = {
+ "skipped": self.skipped,
+ "interrupted": self.interrupted,
+ "job": self.job,
+ "job_count": self.job_count,
+ "job_timestamp": self.job_timestamp,
+ "job_no": self.job_no,
+ "sampling_step": self.sampling_step,
+ "sampling_steps": self.sampling_steps,
+ }
+
+ return obj
+
+ def begin(self, job: str = "(unknown)"):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.processing_has_refined_job_count = False
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.id_live_preview = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+ self.time_start = time.time()
+ self.job = job
+ devices.torch_gc()
+ log.info("Starting job %s", job)
+
+ def end(self):
+ duration = time.time() - self.time_start
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
+ self.job = ""
+ self.job_count = 0
+
+ devices.torch_gc()
+
+ def set_current_image(self):
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ if not shared.parallel_processing_allowed:
+ return
+
+ if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
+ self.do_set_current_image()
+
+ def do_set_current_image(self):
+ if self.current_latent is None:
+ return
+
+ import modules.sd_samplers
+
+ try:
+ if shared.opts.show_progress_grid:
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+ else:
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+ self.current_image_sampling_step = self.sampling_step
+
+ except Exception:
+ # when switching models during generation, the VAE would be on CPU, so creating an image will fail.
+ # we silently ignore this error
+ errors.record_exception()
+
+ def assign_current_image(self, image):
+ self.current_image = image
+ self.id_live_preview += 1
diff --git a/modules/shared_total_tqdm.py b/modules/shared_total_tqdm.py
new file mode 100644
index 00000000..cf82e104
--- /dev/null
+++ b/modules/shared_total_tqdm.py
@@ -0,0 +1,37 @@
+import tqdm
+
+from modules import shared
+
+
+class TotalTQDM:
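+ """Secondary console progress bar (position=1) showing total progress across all jobs;
+ the underlying tqdm instance is created lazily on first update."""
+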
+ def __init__(self):
+ self._tqdm = None
+
+ def reset(self):
+ self._tqdm = tqdm.tqdm(
+ desc="Total progress",
+ total=shared.state.job_count * shared.state.sampling_steps,
+ position=1,
+ file=shared.progress_print_out
+ )
+
+ def update(self):
+ if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
+ return
+ if self._tqdm is None:
+ self.reset()
+ self._tqdm.update()
+
+ def updateTotal(self, new_total):
+ if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
+ return
+ if self._tqdm is None:
+ self.reset()
+ self._tqdm.total = new_total
+
+ def clear(self):
+ if self._tqdm is not None:
+ self._tqdm.refresh()
+ self._tqdm.close()
+ self._tqdm = None
+
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index cf24c6dd..7d906e1f 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -10,7 +10,7 @@ import psutil
import re
import launch
-from modules import paths_internal, timer
+from modules import paths_internal, timer, shared, extensions, errors
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = {
@@ -115,8 +115,6 @@ def format_exception(e, tb):
def get_exceptions():
try:
- from modules import errors
-
return list(reversed(errors.exception_records))
except Exception as e:
return str(e)
@@ -142,8 +140,6 @@ def get_torch_sysinfo():
def get_extensions(*, enabled):
try:
- from modules import extensions
-
def to_json(x: extensions.Extension):
return {
"name": x.name,
@@ -160,7 +156,6 @@ def get_extensions(*, enabled):
def get_config():
try:
- from modules import shared
return shared.opts.data
except Exception as e:
return str(e)
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 935ed418..edad8930 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,7 +1,7 @@
from contextlib import closing
import modules.scripts
-from modules import sd_samplers, processing
+from modules import processing
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
- sampler_name=sd_samplers.samplers[sampler_index].name,
+ sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
@@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
- hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
+ hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
diff --git a/modules/ui.py b/modules/ui.py
index 5150dae4..30b80417 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, devices, ui_extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -29,7 +29,6 @@ import modules.shared as shared
import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
from modules.generation_parameters_copypaste import image_from_url_text
create_setting_component = ui_settings.create_setting_component
@@ -92,8 +91,6 @@ def send_gradio_gallery_to_image(x):
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
- from modules import processing, devices
-
if not enable:
return ""
@@ -360,14 +357,14 @@ def create_output_panel(tabname, outdir):
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
- return steps, sampler_index
+ return steps, sampler_name
def ordered_ui_categories():
@@ -414,7 +411,7 @@ def create_ui():
for category in ordered_ui_categories():
if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+ steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
elif category == "dimensions":
with FormRow():
@@ -460,7 +457,7 @@ def create_ui():
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
- hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
+ hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
@@ -520,7 +517,7 @@ def create_ui():
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
- sampler_index,
+ sampler_name,
restore_faces,
tiling,
batch_count,
@@ -538,7 +535,7 @@ def create_ui():
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
- hr_sampler_index,
+ hr_sampler_name,
hr_prompt,
hr_negative_prompt,
override_settings,
@@ -583,7 +580,7 @@ def create_ui():
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
- (sampler_index, "Sampler"),
+ (sampler_name, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(seed, "Seed"),
@@ -605,7 +602,7 @@ def create_ui():
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"),
- (hr_sampler_index, "Hires sampler"),
+ (hr_sampler_name, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
@@ -621,7 +618,7 @@ def create_ui():
toprow.prompt,
toprow.negative_prompt,
steps,
- sampler_index,
+ sampler_name,
cfg_scale,
seed,
width,
@@ -631,7 +628,6 @@ def create_ui():
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
- from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@@ -744,7 +740,7 @@ def create_ui():
for category in ordered_ui_categories():
if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
+ steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
elif category == "dimensions":
with FormRow():
@@ -876,7 +872,7 @@ def create_ui():
init_img_inpaint,
init_mask_inpaint,
steps,
- sampler_index,
+ sampler_name,
mask_blur,
mask_alpha,
inpainting_fill,
@@ -972,7 +968,7 @@ def create_ui():
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
- (sampler_index, "Sampler"),
+ (sampler_name, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"),
@@ -996,7 +992,6 @@ def create_ui():
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
))
- from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 303af9cd..99d19ff0 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -11,7 +11,7 @@ from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text
import modules.images
from modules.ui_components import ToolButton
-
+import modules.generation_parameters_copypaste as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -105,8 +105,6 @@ def save_files(js_data, images, do_make_zip, index):
def create_output_panel(tabname, outdir):
- from modules import shared
- import modules.generation_parameters_copypaste as parameters_copypaste
def open_folder(f):
if not os.path.exists(f):
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index e0b932b9..16d76a45 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -4,7 +4,6 @@ from pathlib import Path
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo
-from modules.ui import up_down_symbol
import gradio as gr
import json
import html
@@ -348,6 +347,8 @@ def pages_in_preferred_order(pages):
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
+ from modules.ui import up_down_symbol
+
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 1cb9eb6f..a5423fd8 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -36,8 +36,8 @@ class UserMetadataEditor:
item = self.page.items.get(name, {})
user_metadata = item.get('user_metadata', None)
- if user_metadata is None:
- user_metadata = {}
+ if not user_metadata:
+ user_metadata = {'description': item.get('description', '')}
item['user_metadata'] = user_metadata
return user_metadata
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index fb75137e..506017e5 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
return file_obj.name
-# override save to file function so that it also writes PNG info
-gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
+def install_ui_tempdir_override():
+ """override save to file function so that it also writes PNG info"""
+ gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed():
diff --git a/modules/util.py b/modules/util.py
new file mode 100644
index 00000000..60afc067
--- /dev/null
+++ b/modules/util.py
@@ -0,0 +1,58 @@
+import os
+import re
+
+from modules import shared
+from modules.paths_internal import script_path
+
+
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
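+ """Sort key that compares embedded digit runs numerically, so that, for example,
+ sorted(["img10", "img2"], key=natural_sort_key) returns ["img2", "img10"]."""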
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
+def listfiles(dirname):
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
+ return [file for file in filenames if os.path.isfile(file)]
+
+
+def html_path(filename):
+ return os.path.join(script_path, "html", filename)
+
+
+def html(filename):
+ path = html_path(filename)
+
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+
+ return ""
+
+
+def walk_files(path, allowed_extensions=None):
+ if not os.path.exists(path):
+ return
+
+ if allowed_extensions is not None:
+ allowed_extensions = set(allowed_extensions)
+
+ items = list(os.walk(path, followlinks=True))
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+ for root, _, files in items:
+ for filename in sorted(files, key=natural_sort_key):
+ if allowed_extensions is not None:
+ _, ext = os.path.splitext(filename)
+ if ext not in allowed_extensions:
+ continue
+
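+ # honor the "Load models/files in hidden directories" setting; a directory counts
+ # as hidden if any of its path components starts with "."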
+ if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
+ continue
+
+ yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+ if shared.opts.hide_ldm_prints:
+ return
+
+ print(*args, **kwargs)