aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py59
-rw-r--r--modules/cmd_args.py102
-rw-r--r--modules/extensions.py16
-rw-r--r--modules/generation_parameters_copypaste.py7
-rw-r--r--modules/images.py4
-rw-r--r--modules/mac_specific.py9
-rw-r--r--modules/modelloader.py2
-rw-r--r--modules/paths.py11
-rw-r--r--modules/paths_internal.py22
-rw-r--r--modules/processing.py16
-rw-r--r--modules/scripts.py24
-rw-r--r--modules/scripts_postprocessing.py2
-rw-r--r--modules/sd_hijack_optimizations.py4
-rw-r--r--modules/sd_hijack_unet.py2
-rw-r--r--modules/sd_models.py26
-rw-r--r--modules/shared.py112
-rw-r--r--modules/textual_inversion/textual_inversion.py6
-rw-r--r--modules/ui.py71
-rw-r--r--modules/ui_common.py9
-rw-r--r--modules/ui_components.py36
-rw-r--r--modules/ui_extensions.py15
-rw-r--r--modules/ui_extra_networks.py54
22 files changed, 413 insertions, 196 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 35e17afc..13af9ed6 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -6,8 +6,11 @@ import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
+from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
+from fastapi.exceptions import HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
@@ -18,7 +21,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -90,6 +93,16 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
+ rich_available = True
+ try:
+ import anyio # importing just so it can be placed on silent list
+ import starlette # importing just so it can be placed on silent list
+ from rich.console import Console
+ console = Console()
+ except:
+ import traceback
+ rich_available = False
+
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
@@ -110,6 +123,36 @@ def api_middleware(app: FastAPI):
))
return res
+ def handle_exception(request: Request, e: Exception):
+ err = {
+ "error": type(e).__name__,
+ "detail": vars(e).get('detail', ''),
+ "body": vars(e).get('body', ''),
+ "errors": str(e),
+ }
+ print(f"API error: {request.method}: {request.url} {err}")
+ if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ if rich_available:
+ console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
+ else:
+ traceback.print_exc()
+ return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
+
+ @app.middleware("http")
+ async def exception_handling(request: Request, call_next):
+ try:
+ return await call_next(request)
+ except Exception as e:
+ return handle_exception(request, e)
+
+ @app.exception_handler(Exception)
+ async def fastapi_exception_handler(request: Request, e: Exception):
+ return handle_exception(request, e)
+
+ @app.exception_handler(HTTPException)
+ async def http_exception_handler(request: Request, e: HTTPException):
+ return handle_exception(request, e)
+
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -150,6 +193,8 @@ class Api:
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
def add_api_route(self, path: str, endpoint, **kwargs):
@@ -412,6 +457,16 @@ class Api:
return {}
+ def unloadapi(self):
+ unload_model_weights()
+
+ return {}
+
+ def reloadapi(self):
+ reload_model_weights()
+
+ return {}
+
def skip(self):
shared.state.skip()
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
new file mode 100644
index 00000000..0af87251
--- /dev/null
+++ b/modules/cmd_args.py
@@ -0,0 +1,102 @@
+import argparse
+import os
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
+parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
+parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
+parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
+parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
+parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
+parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
+parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
+parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
+parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
+parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
+parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
+parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
+parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
+parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
+parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
+parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
+parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
+parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
+parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
+parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
+parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
+parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
+parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
+parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
+parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
+parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
+parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
+parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
+parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
+parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
+parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
+parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
+parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
+parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
+parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
+parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
+parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
+parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
+parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
+parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
+parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
+parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
+parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
+parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
+parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
+parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
+parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
+parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
+parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
+parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
+parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
+parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
+parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
+parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
+parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
+parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
+parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
+parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
+parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
+parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
+parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
+parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
+parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
diff --git a/modules/extensions.py b/modules/extensions.py
index ed4b58fe..8107a933 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -8,11 +8,9 @@ import git
from modules import paths, shared
extensions = []
-extensions_dir = os.path.join(paths.data_path, "extensions")
-extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
-if not os.path.exists(extensions_dir):
- os.makedirs(extensions_dir)
+if not os.path.exists(paths.extensions_dir):
+ os.makedirs(paths.extensions_dir)
def active():
return [x for x in extensions if x.enabled]
@@ -86,11 +84,11 @@ class Extension:
def list_extensions():
extensions.clear()
- if not os.path.isdir(extensions_dir):
+ if not os.path.isdir(paths.extensions_dir):
return
- paths = []
- for dirname in [extensions_dir, extensions_builtin_dir]:
+ extension_paths = []
+ for dirname in [paths.extensions_dir, paths.extensions_builtin_dir]:
if not os.path.isdir(dirname):
return
@@ -99,9 +97,9 @@ def list_extensions():
if not os.path.isdir(path):
continue
- paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+ extension_paths.append((extension_dirname, path, dirname == paths.extensions_builtin_dir))
- for dirname, path, is_builtin in paths:
+ for dirname, path, is_builtin in extension_paths:
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
extensions.append(extension)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 7c0b5b4e..6df76858 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -401,9 +401,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
button.click(
fn=paste_func,
- _js=f"recalculate_prompts_{tabname}",
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
+ button.click(
+ fn=None,
+ _js=f"recalculate_prompts_{tabname}",
+ inputs=[],
+ outputs=[],
+ )
diff --git a/modules/images.py b/modules/images.py
index 2da988ee..7030aaaa 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -645,6 +645,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
def image_data(data):
+ import gradio as gr
+
try:
image = Image.open(io.BytesIO(data))
textinfo, _ = read_info_from_image(image)
@@ -660,7 +662,7 @@ def image_data(data):
except Exception:
pass
- return '', None
+ return gr.update(), None
def flatten(img, bgcolor):
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 18e6ff72..6fe8dea0 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,4 +1,5 @@
import torch
+import platform
from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -32,6 +33,10 @@ if has_mps:
# MPS fix for randn in torchsde
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
+ if platform.mac_ver()[0].startswith("13.2."):
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
+ CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
+
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
@@ -49,4 +54,6 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-
+ if version.parse(torch.__version__) == version.parse("2.0"):
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e351d808..522affc6 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -4,7 +4,6 @@ import shutil
import importlib
from urllib.parse import urlparse
-from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
+ from basicsr.utils.download_util import load_file_from_url
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
diff --git a/modules/paths.py b/modules/paths.py
index d991cc71..0e1e00e7 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,16 +1,9 @@
-import argparse
import os
import sys
-import modules.safe
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+import modules.safe
-# Parse the --data-dir flag first so we can use it as a base for our other argument default values
-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
-cmd_opts_pre = parser.parse_known_args()[0]
-data_path = cmd_opts_pre.data_dir
-models_path = os.path.join(data_path, "models")
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
new file mode 100644
index 00000000..926ec3bb
--- /dev/null
+++ b/modules/paths_internal.py
@@ -0,0 +1,22 @@
+"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
+
+import argparse
+import os
+
+script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+sd_configs_path = os.path.join(script_path, "configs")
+sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
+sd_model_file = os.path.join(script_path, 'model.ckpt')
+default_sd_model_file = sd_model_file
+
+# Parse the --data-dir flag first so we can use it as a base for our other argument default values
+parser_pre = argparse.ArgumentParser(add_help=False)
+parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+cmd_opts_pre = parser_pre.parse_known_args()[0]
+
+data_path = cmd_opts_pre.data_dir
+
+models_path = os.path.join(data_path, "models")
+extensions_dir = os.path.join(data_path, "extensions")
+extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
diff --git a/modules/processing.py b/modules/processing.py
index 59717b4c..2e5a363f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -689,6 +689,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
+ if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+ image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
+
+ if opts.save_mask:
+ images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+
+ if opts.save_mask_composite:
+ images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+
+ if opts.return_mask:
+ output_images.append(image_mask)
+
+ if opts.return_mask_composite:
+ output_images.append(image_mask_composite)
+
del x_samples_ddim
devices.torch_gc()
diff --git a/modules/scripts.py b/modules/scripts.py
index 8de19884..d661be4f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -239,7 +239,15 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
- for scriptfile in sorted(scripts_list):
+ def orderby(basedir):
+ # 1st webui, 2nd extensions-builtin, 3rd extensions
+ priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
+ for key in priority:
+ if basedir.startswith(key):
+ return priority[key]
+ return 9999
+
+ for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
@@ -513,6 +521,18 @@ def reload_scripts():
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+def add_classes_to_gradio_component(comp):
+ """
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+ """
+
+ comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
+
+ if getattr(comp, 'multiselect', False):
+ comp.elem_classes.append('multiselect')
+
+
+
def IOComponent_init(self, *args, **kwargs):
if scripts_current is not None:
scripts_current.before_component(self, **kwargs)
@@ -521,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
res = original_IOComponent_init(self, *args, **kwargs)
+ add_classes_to_gradio_component(self)
+
script_callbacks.after_component_callback(self, **kwargs)
if scripts_current is not None:
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index ce0ebb61..b11568c0 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
inputs = []
for script in self.scripts_in_preferred_order():
- with gr.Box() as group:
+ with gr.Row() as group:
self.create_script_ui(script, inputs)
script.group = group
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 2e307b5d..372555ff 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
dtype = q.dtype
if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
+ q, k, v = q.float(), k.float(), v.float()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
@@ -372,7 +372,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
dtype = q.dtype
if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
+ q, k, v = q.float(), k.float(), v.float()
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 843ab66c..15858263 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -67,7 +67,7 @@ def hijack_ddpm_edit():
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
-if version.parse(torch.__version__) <= version.parse("1.13.1"):
+if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f0cb1240..86218c08 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -178,7 +178,7 @@ def select_checkpoint():
return checkpoint_info
-chckpoint_dict_replacements = {
+checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
@@ -186,7 +186,7 @@ chckpoint_dict_replacements = {
def transform_checkpoint_dict_key(k):
- for text, replacement in chckpoint_dict_replacements.items():
+ for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
@@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
checkpoints_loaded.clear()
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model
try:
@@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
return sd_model
+
+def unload_model_weights(sd_model=None, info=None):
+ from modules import lowvram, devices, sd_hijack
+ timer = Timer()
+
+ if shared.sd_model:
+
+ # shared.sd_model.cond_stage_model.to(devices.cpu)
+ # shared.sd_model.first_stage_model.to(devices.cpu)
+ shared.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ sd_model = None
+ gc.collect()
+ devices.torch_gc()
+ torch.cuda.empty_cache()
+
+ print(f"Unloaded weights {timer.summary()}.")
+
+ return sd_model \ No newline at end of file
diff --git a/modules/shared.py b/modules/shared.py
index f28a12cc..11be3985 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,114 +13,22 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, extensions, script_loading, errors, ui_components, shared_items
-from modules.paths import models_path, script_path, data_path
-
+from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
demo = None
-sd_configs_path = os.path.join(script_path, "configs")
-sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
-sd_model_file = os.path.join(script_path, 'model.ckpt')
-default_sd_model_file = sd_model_file
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
-parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
-parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
-parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
-parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
-parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
-parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
-parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
-parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
-parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
-parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
-parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
-parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
-parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
-parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
-parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
-parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
-parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
-parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
-parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
-parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
-parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
-parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
-parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
-parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
-parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
-parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
-parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
-parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
-parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
-parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
-parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
-parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
-parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
-parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
-parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
-parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
-parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
-parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
-parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
-parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
-parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
-parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
-parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
-parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
-parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
-parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
-parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
-parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
-parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
-parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
-parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
-parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
-parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
-parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
-parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
-parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
-parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
-parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
-parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
-parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
-parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
-parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
-parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
-parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
-parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
-parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
-parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
-parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
-parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
-parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
-parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
-parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
-
-
-script_loading.preload_extensions(extensions.extensions_dir, parser)
-script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
+parser = cmd_args.parser
+
+script_loading.preload_extensions(extensions_dir, parser)
+script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
+
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
@@ -332,6 +240,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
+ "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
+ "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
@@ -448,12 +358,16 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
+ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"),
+ "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+ "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c63c7d1d..d2e62e58 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -152,7 +152,11 @@ class EmbeddingDatabase:
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
+ if data:
+ name = data.get('name', name)
+ else:
+ # if data is None, means this is not an embeding, just a preview image
+ return
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
diff --git a/modules/ui.py b/modules/ui.py
index 7e603332..af8546c2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
-from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
+from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -89,7 +89,7 @@ paste_symbol = '\u2199\ufe0f' # ↙
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
-clear_prompt_symbol = '\U0001F5D1' # 🗑️
+clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
@@ -179,14 +179,13 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
+ with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+ random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
@@ -195,8 +194,8 @@ def create_seed_inputs(target_interface):
seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
with FormRow(visible=False) as seed_extra_row_2:
@@ -291,19 +290,19 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
button_interrogate = None
button_deepbooru = None
if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
+ with gr.Column(scale=1, elem_classes="interrogate-col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box"):
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
@@ -325,9 +324,9 @@ def create_toprow(is_img2img):
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
+ negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
clear_prompt_button.click(
@@ -479,7 +478,9 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+ with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
@@ -492,7 +493,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@@ -586,7 +587,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
txt_prompt_img.change(
fn=modules.images.image_data,
@@ -757,7 +758,9 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
@@ -774,7 +777,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
- with FormRow(elem_id="img2img_checkboxes", variant="compact"):
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
@@ -904,7 +907,7 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
@@ -1491,11 +1494,33 @@ def create_ui():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ with gr.Row():
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.TabItem("Licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+
+ def unload_sd_weights():
+ modules.sd_models.unload_model_weights()
+
+ def reload_sd_weights():
+ modules.sd_models.reload_model_weights()
+
+ unload_sd_model.click(
+ fn=unload_sd_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ reload_sd_model.click(
+ fn=reload_sd_weights,
+ inputs=[],
+ outputs=[]
+ )
request_notifications.click(
fn=lambda: None,
@@ -1598,11 +1623,13 @@ def create_ui():
for i, k, item in quicksettings_list:
component = component_dict[k]
+ info = opts.data_labels[k]
component.change(
fn=lambda value, k=k: run_settings_single(value, key=k),
inputs=[component],
outputs=[component, text_settings],
+ show_progress=info.refresh is not None,
)
text_settings.change(
diff --git a/modules/ui_common.py b/modules/ui_common.py
index a12433d2..0f3427c8 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -129,8 +129,8 @@ Requested path was: {f}
generation_info = None
with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+ with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
+ open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}')
@@ -149,7 +149,7 @@ Requested path was: {f}
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
@@ -160,6 +160,7 @@ Requested path was: {f}
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info],
+ show_progress=False,
)
save.click(
@@ -195,7 +196,7 @@ Requested path was: {f}
else:
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = []
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 284ca0cf..2b1da2cb 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -1,55 +1,61 @@
import gradio as gr
-class ToolButton(gr.Button, gr.components.FormComponent):
- """Small button with single emoji as text, fits inside gradio forms"""
+class FormComponent:
+ def get_expected_parent(self):
+ return gr.components.Form
- def __init__(self, **kwargs):
- super().__init__(variant="tool", **kwargs)
- def get_block_name(self):
- return "button"
+gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
-class ToolButtonTop(gr.Button, gr.components.FormComponent):
- """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+class ToolButton(FormComponent, gr.Button):
+ """Small button with single emoji as text, fits inside gradio forms"""
- def __init__(self, **kwargs):
- super().__init__(variant="tool-top", **kwargs)
+ def __init__(self, *args, **kwargs):
+ classes = kwargs.pop("elem_classes", [])
+ super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
def get_block_name(self):
return "button"
-class FormRow(gr.Row, gr.components.FormComponent):
+class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
-class FormGroup(gr.Group, gr.components.FormComponent):
+class FormColumn(FormComponent, gr.Column):
+ """Same as gr.Column but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "column"
+
+
+class FormGroup(FormComponent, gr.Group):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "group"
-class FormHTML(gr.HTML, gr.components.FormComponent):
+class FormHTML(FormComponent, gr.HTML):
"""Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self):
return "html"
-class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
+class FormColorPicker(FormComponent, gr.ColorPicker):
"""Same as gr.ColorPicker but fits inside gradio forms"""
def get_block_name(self):
return "colorpicker"
-class DropdownMulti(gr.Dropdown):
+class DropdownMulti(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 5ddb3fdb..da7e79f0 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,6 +1,5 @@
import json
import os.path
-import shutil
import sys
import time
import traceback
@@ -141,22 +140,20 @@ def install_extension_from_url(dirname, url):
try:
shutil.rmtree(tmpdir, True)
-
- repo = git.Repo.clone_from(url, tmpdir)
- repo.remote().fetch()
-
+ with git.Repo.clone_from(url, tmpdir) as repo:
+ repo.remote().fetch()
+ for submodule in repo.submodules:
+ submodule.update()
try:
os.rename(tmpdir, target_dir)
except OSError as err:
- # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
- # Shouldn't cause any new issues at least but we probably want to handle it there too.
if err.errno == errno.EXDEV:
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
# Since we can't use a rename, do the slower but more versitile shutil.move()
shutil.move(tmpdir, target_dir)
else:
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
- raise(err)
+ raise err
import launch
launch.run_extension_installer(target_dir)
@@ -255,7 +252,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
hidden += 1
continue
- install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+ install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index cdfd6f2a..daea03d6 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -22,21 +22,37 @@ def register_page(page):
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
-def add_pages_to_demo(app):
- def fetch_file(filename: str = ""):
- from starlette.responses import FileResponse
+def fetch_file(filename: str = ""):
+ from starlette.responses import FileResponse
+
+ if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+
+ ext = os.path.splitext(filename)[1].lower()
+ if ext not in (".png", ".jpg", ".webp"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+
+ # would profit from returning 304
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
+
+def get_metadata(page: str = "", item: str = ""):
+ from starlette.responses import JSONResponse
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
- raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
+ if page is None:
+ return JSONResponse({})
- ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".webp"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+ metadata = page.metadata.get(item)
+ if metadata is None:
+ return JSONResponse({})
- # would profit from returning 304
- return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+ return JSONResponse({"metadata": metadata})
+
+def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
class ExtraNetworksPage:
@@ -45,6 +61,7 @@ class ExtraNetworksPage:
self.name = title.lower()
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
+ self.metadata = {}
def refresh(self):
pass
@@ -66,6 +83,8 @@ class ExtraNetworksPage:
view = shared.opts.extra_networks_default_view
items_html = ''
+ self.metadata = {}
+
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
@@ -86,12 +105,16 @@ class ExtraNetworksPage:
subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f"""
-<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
+<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")}
</button>
""" for subdir in subdirs])
for item in self.list_items():
+ metadata = item.get("metadata")
+ if metadata:
+ self.metadata[item["name"]] = metadata
+
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
@@ -124,14 +147,16 @@ class ExtraNetworksPage:
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
+ width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
+ background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
- metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
+ metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
args = {
- "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
+ "style": f"'{height}{width}{background_image}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
@@ -215,6 +240,7 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages:
with gr.Tab(page.title):
+
page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem)