aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py15
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/call_queue.py23
-rw-r--r--modules/cmd_args.py4
-rw-r--r--modules/codeformer_model.py11
-rw-r--r--modules/config_states.py13
-rw-r--r--modules/devices.py18
-rw-r--r--modules/errors.py20
-rw-r--r--modules/extensions.py27
-rw-r--r--modules/extra_networks.py2
-rw-r--r--modules/generation_parameters_copypaste.py18
-rw-r--r--modules/gfpgan_model.py7
-rw-r--r--modules/gitpython_hack.py42
-rw-r--r--modules/hypernetworks/hypernetwork.py15
-rw-r--r--modules/images.py52
-rw-r--r--modules/img2img.py3
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/launch_utils.py331
-rw-r--r--modules/localization.py6
-rw-r--r--modules/paths.py1
-rw-r--r--modules/processing.py17
-rw-r--r--modules/realesrgan_model.py16
-rw-r--r--modules/safe.py27
-rw-r--r--modules/script_callbacks.py29
-rw-r--r--modules/script_loading.py7
-rw-r--r--modules/scripts.py152
-rw-r--r--modules/sd_hijack.py37
-rw-r--r--modules/sd_hijack_optimizations.py12
-rw-r--r--modules/sd_models.py26
-rw-r--r--modules/sd_samplers_kdiffusion.py64
-rw-r--r--modules/sd_unet.py92
-rw-r--r--modules/shared.py35
-rw-r--r--modules/shared_items.py38
-rw-r--r--modules/textual_inversion/image_embedding.py21
-rw-r--r--modules/textual_inversion/textual_inversion.py35
-rw-r--r--modules/ui.py394
-rw-r--r--modules/ui_common.py32
-rw-r--r--modules/ui_extensions.py20
-rw-r--r--modules/ui_gradio_extensions.py69
-rw-r--r--modules/ui_settings.py263
-rw-r--r--modules/ui_tempdir.py15
-rw-r--r--modules/upscaler.py4
42 files changed, 1377 insertions, 644 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index eee99bbb..d34ab422 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -108,7 +109,6 @@ def api_middleware(app: FastAPI):
from rich.console import Console
console = Console()
except Exception:
- import traceback
rich_available = False
@app.middleware("http")
@@ -139,11 +139,12 @@ def api_middleware(app: FastAPI):
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
- print(f"API error: {request.method}: {request.url} {err}")
+ message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
+ print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- traceback.print_exc()
+ errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
@@ -189,6 +190,7 @@ class Api:
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
@@ -541,6 +543,9 @@ class Api:
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ def get_sd_vaes(self):
+ return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -700,4 +705,4 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
diff --git a/modules/api/models.py b/modules/api/models.py
index 1ff2fb33..47fdede2 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,10 @@ class SDModelItem(BaseModel):
filename: str = Field(title="Filename")
config: Optional[str] = Field(title="Config file")
+class SDVaeItem(BaseModel):
+ model_name: str = Field(title="Model Name")
+ filename: str = Field(title="Filename")
+
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 447bb764..53af6d70 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,10 +1,8 @@
import html
-import sys
import threading
-import traceback
import time
-from modules import shared, progress
+from modules import shared, progress, errors
queue_lock = threading.Lock()
@@ -56,16 +54,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
try:
res = list(func(*args, **kwargs))
except Exception as e:
- # When printing out our debug argument list, do not print out more than a MB of text
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {args} {kwargs}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
+ # When printing out our debug argument list,
+ # do not print out more than a 100 KB of text
+ max_debug_str_len = 131072
+ message = "Error completing request"
+ arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
+ if len(arg_str) > max_debug_str_len:
+ arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+ errors.report(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
@@ -108,4 +104,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res)
return f
-
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 3eeb84d5..de905caa 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
-parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
@@ -62,7 +62,7 @@ parser.add_argument("--opt-split-attention-invokeai", action='store_true', help=
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ececdbae..4260b016 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,13 +1,11 @@
import os
-import sys
-import traceback
import cv2
import torch
import modules.face_restoration
import modules.shared
-from modules import shared, devices, modelloader
+from modules import shared, devices, modelloader, errors
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -105,8 +103,8 @@ def setup_model(dirname):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ except Exception:
+ errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -135,7 +133,6 @@ def setup_model(dirname):
shared.face_restorers.append(codeformer)
except Exception:
- print("Error setting up CodeFormer:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index db65bcdb..6f1ab53f 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
"""
import os
-import sys
-import traceback
import json
import time
import tqdm
@@ -13,7 +11,7 @@ from datetime import datetime
from collections import OrderedDict
import git
-from modules import shared, extensions
+from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir
@@ -53,8 +51,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -134,8 +131,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -143,8 +139,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/devices.py b/modules/devices.py
index d8a34a0f..1ed6ffdc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,5 +1,7 @@
import sys
import contextlib
+from functools import lru_cache
+
import torch
from modules import errors
@@ -154,3 +156,19 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+ """
+ just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
+ spends about 2.7 seconds doing that, at least wih NVidia.
+ """
+
+ x = torch.zeros((1, 1)).to(device, dtype)
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
+ linear(x)
+
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+ conv2d(x)
diff --git a/modules/errors.py b/modules/errors.py
index f6b80dbb..e408f500 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -1,7 +1,19 @@
import sys
+import textwrap
import traceback
+def report(message: str, *, exc_info: bool = False) -> None:
+ """
+ Print an error message to stderr, with optional traceback.
+ """
+ for line in message.splitlines():
+ print("***", line, file=sys.stderr)
+ if exc_info:
+ print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
+ print("---", file=sys.stderr)
+
+
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
@@ -12,9 +24,13 @@ def print_error_explanation(message):
print('=' * max_len, file=sys.stderr)
-def display(e: Exception, task):
+def display(e: Exception, task, *, full_traceback=False):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ te = traceback.TracebackException.from_exception(e)
+ if full_traceback:
+ # include frames leading up to the try-catch block
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
+ print(*te.format(), sep="", file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
diff --git a/modules/extensions.py b/modules/extensions.py
index 359a7aa5..8608584b 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,11 +1,8 @@
import os
-import sys
import threading
-import traceback
-import git
-
-from modules import shared
+from modules import shared, errors
+from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
@@ -54,10 +51,9 @@ class Extension:
repo = None
try:
if os.path.exists(os.path.join(self.path, ".git")):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
except Exception:
- print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -65,14 +61,15 @@ class Extension:
try:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
- self.commit_date = repo.head.commit.committed_date
+ commit = repo.head.commit
+ self.commit_date = commit.committed_date
if repo.active_branch:
self.branch = repo.active_branch.name
- self.commit_hash = repo.head.commit.hexsha
- self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
+ self.commit_hash = commit.hexsha
+ self.version = self.commit_hash[:8]
- except Exception as ex:
- print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
+ except Exception:
+ errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
self.have_info_from_repo = True
@@ -93,7 +90,7 @@ class Extension:
return res
def check_updates(self):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
@@ -115,7 +112,7 @@ class Extension:
self.status = "latest"
def fetch_and_reset_hard(self, commit='origin'):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 34a3ba63..f4743928 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -26,7 +26,7 @@ class ExtraNetworkParams:
self.named = {}
for item in self.items:
- parts = item.split('=', 2)
+ parts = item.split('=', 2) if isinstance(item, str) else [item]
if len(parts) == 2:
self.named[parts[0]] = parts[1]
else:
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d5f0a49b..071bd9ea 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -35,7 +35,7 @@ def reset():
def quote(text):
- if ',' not in str(text) and '\n' not in str(text):
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
return text
return json.dumps(text, ensure_ascii=False)
@@ -306,6 +306,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "RNG" not in res:
res["RNG"] = "GPU"
+ if "Schedule type" not in res:
+ res["Schedule type"] = "Automatic"
+
+ if "Schedule max sigma" not in res:
+ res["Schedule max sigma"] = 0
+
+ if "Schedule min sigma" not in res:
+ res["Schedule min sigma"] = 0
+
+ if "Schedule rho" not in res:
+ res["Schedule rho"] = 0
+
return res
@@ -318,6 +330,10 @@ infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
+ ('Schedule type', 'k_sched_type'),
+ ('Schedule max sigma', 'sigma_max'),
+ ('Schedule min sigma', 'sigma_min'),
+ ('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 0131dea4..e239a09d 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,12 +1,10 @@
import os
-import sys
-import traceback
import facexlib
import gfpgan
import modules.face_restoration
-from modules import paths, shared, devices, modelloader
+from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN"
user_path = None
@@ -112,5 +110,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print("Error setting up GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py
new file mode 100644
index 00000000..e537c1df
--- /dev/null
+++ b/modules/gitpython_hack.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import io
+import subprocess
+
+import git
+
+
+class Git(git.Git):
+ """
+ Git subclassed to never use persistent processes.
+ """
+
+ def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
+ raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
+
+ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
+ ret = subprocess.check_output(
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
+ input=self._prepare_ref(ref),
+ cwd=self._working_dir,
+ timeout=2,
+ )
+ return self._parse_object_header(ret)
+
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+ # Not really streaming, per se; this buffers the entire object in memory.
+ # Shouldn't be a problem for our use case, since we're only using this for
+ # object headers (commit objects).
+ ret = subprocess.check_output(
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
+ input=self._prepare_ref(ref),
+ cwd=self._working_dir,
+ timeout=30,
+ )
+ bio = io.BytesIO(ret)
+ hexsha, typename, size = self._parse_object_header(bio.readline())
+ return (hexsha, typename, size, self.CatFileContentStream(size, bio))
+
+
+class Repo(git.Repo):
+ GitCommandWrapperType = Git
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 570b5603..5d12b449 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -2,8 +2,6 @@ import datetime
import glob
import html
import os
-import sys
-import traceback
import inspect
import modules.textual_inversion.dataset
@@ -11,7 +9,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -325,17 +323,14 @@ def load_hypernetwork(name):
if path is None:
return None
- hypernetwork = Hypernetwork()
-
try:
+ hypernetwork = Hypernetwork()
hypernetwork.load(path)
+ return hypernetwork
except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading hypernetwork {path}", exc_info=True)
return None
- return hypernetwork
-
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
@@ -770,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/images.py b/modules/images.py
index 4e8cd993..a12d252b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,6 +1,4 @@
import datetime
-import sys
-import traceback
import pytz
import io
@@ -21,6 +19,8 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
+import modules.sd_vae as sd_vae
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -336,8 +336,20 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
+ def get_vae_filename(self): #get the name of the VAE file.
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
+ else:
+ return split_file_name[0]
+
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.image.width,
@@ -354,19 +366,23 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
- 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+ 'batch_size': lambda self: self.p.batch_size,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'vae_filename': lambda self: self.get_vae_filename(),
+
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt, image):
+ def __init__(self, p, seed, prompt, image, zip=False):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
+ self.zip = zip
def hasprompt(self, *args):
lower = self.prompt.lower()
@@ -446,8 +462,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -488,14 +503,13 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
image_format = Image.registered_extensions()[extension]
- existing_pnginfo = existing_pnginfo or {}
- if opts.enable_pnginfo:
- existing_pnginfo['parameters'] = geninfo
-
if extension.lower() == '.png':
- pnginfo_data = PngImagePlugin.PngInfo()
- for k, v in (existing_pnginfo or {}).items():
- pnginfo_data.add_text(k, str(v))
+ if opts.enable_pnginfo:
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in (existing_pnginfo or {}).items():
+ pnginfo_data.add_text(k, str(v))
+ else:
+ pnginfo_data = None
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
@@ -665,9 +679,10 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity']:
+ items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
@@ -678,8 +693,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/img2img.py b/modules/img2img.py
index d704bf90..4c12c2c5 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -92,7 +92,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 111b1322..9b2c5b60 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,6 +1,5 @@
import os
import sys
-import traceback
from collections import namedtuple
from pathlib import Path
import re
@@ -216,8 +215,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print("Error interrogating", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error interrogating", exc_info=True)
res += "<error>"
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
new file mode 100644
index 00000000..6e9bb770
--- /dev/null
+++ b/modules/launch_utils.py
@@ -0,0 +1,331 @@
+# this scripts installs necessary requirements and launches main program in webui.py
+import subprocess
+import os
+import sys
+import importlib.util
+import platform
+import json
+from functools import lru_cache
+
+from modules import cmd_args, errors
+from modules.paths_internal import script_path, extensions_dir
+
+args, _ = cmd_args.parser.parse_known_args()
+
+python = sys.executable
+git = os.environ.get('GIT', "git")
+index_url = os.environ.get('INDEX_URL', "")
+dir_repos = "repositories"
+
+# Whether to default to printing command output
+default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
+
+if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
+
+def check_python_version():
+ is_windows = platform.system() == "Windows"
+ major = sys.version_info.major
+ minor = sys.version_info.minor
+ micro = sys.version_info.micro
+
+ if is_windows:
+ supported_minors = [10]
+ else:
+ supported_minors = [7, 8, 9, 10, 11]
+
+ if not (major == 3 and minor in supported_minors):
+ import modules.errors
+
+ modules.errors.print_error_explanation(f"""
+INCOMPATIBLE PYTHON VERSION
+
+This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
+If you encounter an error with "RuntimeError: Couldn't install torch." message,
+or any other error regarding unsuccessful package (library) installation,
+please downgrade (or upgrade) to the latest version of 3.10 Python
+and delete current Python and "venv" folder in WebUI's directory.
+
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
+
+{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
+
+Use --skip-python-version-check to suppress this warning.
+""")
+
+
+@lru_cache()
+def commit_hash():
+ try:
+ return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
+ except Exception:
+ return "<none>"
+
+
+@lru_cache()
+def git_tag():
+ try:
+ return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
+ except Exception:
+ return "<none>"
+
+
+def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
+ if desc is not None:
+ print(desc)
+
+ run_kwargs = {
+ "args": command,
+ "shell": True,
+ "env": os.environ if custom_env is None else custom_env,
+ "encoding": 'utf8',
+ "errors": 'ignore',
+ }
+
+ if not live:
+ run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
+
+ result = subprocess.run(**run_kwargs)
+
+ if result.returncode != 0:
+ error_bits = [
+ f"{errdesc or 'Error running command'}.",
+ f"Command: {command}",
+ f"Error code: {result.returncode}",
+ ]
+ if result.stdout:
+ error_bits.append(f"stdout: {result.stdout}")
+ if result.stderr:
+ error_bits.append(f"stderr: {result.stderr}")
+ raise RuntimeError("\n".join(error_bits))
+
+ return (result.stdout or "")
+
+
+def is_installed(package):
+ try:
+ spec = importlib.util.find_spec(package)
+ except ModuleNotFoundError:
+ return False
+
+ return spec is not None
+
+
+def repo_dir(name):
+ return os.path.join(script_path, dir_repos, name)
+
+
+def run_pip(command, desc=None, live=default_command_live):
+ if args.skip_install:
+ return
+
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
+ return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
+
+
+def check_run_python(code: str) -> bool:
+ result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
+ return result.returncode == 0
+
+
+def git_clone(url, dir, name, commithash=None):
+ # TODO clone into temporary dir and move if successful
+
+ if os.path.exists(dir):
+ if commithash is None:
+ return
+
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+ if current_hash == commithash:
+ return
+
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ return
+
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+
+ if commithash is not None:
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
+
+
+def git_pull_recursive(dir):
+ for subdir, _, _ in os.walk(dir):
+ if os.path.exists(os.path.join(subdir, '.git')):
+ try:
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
+ except subprocess.CalledProcessError as e:
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
+
+
+def version_check(commit):
+ try:
+ import requests
+ commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
+ if commit != "<none>" and commits['commit']['sha'] != commit:
+ print("--------------------------------------------------------")
+ print("| You are not up to date with the most recent release. |")
+ print("| Consider running `git pull` to update. |")
+ print("--------------------------------------------------------")
+ elif commits['commit']['sha'] == commit:
+ print("You are up to date with the most recent release.")
+ else:
+ print("Not a git clone, can't perform version check.")
+ except Exception as e:
+ print("version check failed", e)
+
+
+def run_extension_installer(extension_dir):
+ path_installer = os.path.join(extension_dir, "install.py")
+ if not os.path.isfile(path_installer):
+ return
+
+ try:
+ env = os.environ.copy()
+ env['PYTHONPATH'] = os.path.abspath(".")
+
+ print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
+ except Exception as e:
+ errors.report(str(e))
+
+
+def list_extensions(settings_file):
+ settings = {}
+
+ try:
+ if os.path.isfile(settings_file):
+ with open(settings_file, "r", encoding="utf8") as file:
+ settings = json.load(file)
+ except Exception:
+ errors.report("Could not load settings", exc_info=True)
+
+ disabled_extensions = set(settings.get('disabled_extensions', []))
+ disable_all_extensions = settings.get('disable_all_extensions', 'none')
+
+ if disable_all_extensions != 'none':
+ return []
+
+ return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
+
+
+def run_extensions_installers(settings_file):
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname_extension in list_extensions(settings_file):
+ run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+
+
+def prepare_environment():
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
+
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
+ clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
+
+ stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+ k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
+ codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
+ blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
+
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
+ codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
+
+ if not args.skip_python_version_check:
+ check_python_version()
+
+ commit = commit_hash()
+ tag = git_tag()
+
+ print(f"Python {sys.version}")
+ print(f"Version: {tag}")
+ print(f"Commit hash: {commit}")
+
+ if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
+
+ if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
+ raise RuntimeError(
+ 'Torch is not able to use GPU; '
+ 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
+ )
+
+ if not is_installed("gfpgan"):
+ run_pip(f"install {gfpgan_package}", "gfpgan")
+
+ if not is_installed("clip"):
+ run_pip(f"install {clip_package}", "clip")
+
+ if not is_installed("open_clip"):
+ run_pip(f"install {openclip_package}", "open_clip")
+
+ if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
+ if platform.system() == "Windows":
+ if platform.python_version().startswith("3.10"):
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
+ else:
+ print("Installation of xformers is not supported in this version of Python.")
+ print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
+ if not is_installed("xformers"):
+ exit(0)
+ elif platform.system() == "Linux":
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+
+ if not is_installed("ngrok") and args.ngrok:
+ run_pip("install ngrok", "ngrok")
+
+ os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
+
+ git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+ git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
+ git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
+ git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
+
+ if not is_installed("lpips"):
+ run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
+
+ if not os.path.isfile(requirements_file):
+ requirements_file = os.path.join(script_path, requirements_file)
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
+
+ run_extensions_installers(settings_file=args.ui_settings_file)
+
+ if args.update_check:
+ version_check(commit)
+
+ if args.update_all_extensions:
+ git_pull_recursive(extensions_dir)
+
+ if "--exit" in sys.argv:
+ print("Exiting because of --exit argument")
+ exit(0)
+
+
+def configure_for_tests():
+ if "--api" not in sys.argv:
+ sys.argv.append("--api")
+ if "--ckpt" not in sys.argv:
+ sys.argv.append("--ckpt")
+ sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
+ if "--skip-torch-cuda-test" not in sys.argv:
+ sys.argv.append("--skip-torch-cuda-test")
+ if "--disable-nan-check" not in sys.argv:
+ sys.argv.append("--disable-nan-check")
+
+ os.environ['COMMANDLINE_ARGS'] = ""
+
+
+def start():
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
+ import webui
+ if '--nowebui' in sys.argv:
+ webui.api_only()
+ else:
+ webui.webui()
diff --git a/modules/localization.py b/modules/localization.py
index ee9c65e7..e8f585da 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,8 +1,7 @@
import json
import os
-import sys
-import traceback
+from modules import errors
localizations = {}
@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/paths.py b/modules/paths.py
index 5f6474c0..5171df4f 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
diff --git a/modules/processing.py b/modules/processing.py
index 29a3743f..baa9b278 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,5 @@
import json
+import logging
import math
import os
import sys
@@ -13,7 +14,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -23,7 +24,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -321,14 +321,13 @@ class StableDiffusionProcessing:
have been used before. The second element is where the previously
computed result is stored.
"""
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
return cache[1]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps)
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
return cache[1]
def setup_conds(self):
@@ -589,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
def process_images(p: StableDiffusionProcessing) -> Processed:
+ if p.scripts is not None:
+ p.scripts.before_process(p)
+
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
- if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+ override_checkpoint = p.override_settings.get('sd_model_checkpoint')
+ if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None:
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
@@ -674,6 +677,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 99983678..2d27b321 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
import numpy as np
from PIL import Image
@@ -9,7 +7,8 @@ from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
-from modules import modelloader
+from modules import modelloader, errors
+
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
return info
- except Exception as e:
- print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ except Exception:
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _):
@@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
]
return models
except Exception:
- print("Error making Real-ESRGAN models list:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/safe.py b/modules/safe.py
index e8f50774..b1d08a79 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -2,8 +2,6 @@
import pickle
import collections
-import sys
-import traceback
import torch
import numpy
@@ -11,7 +9,10 @@ import _codecs
import zipfile
import re
+
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+from modules import errors
+
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ errors.report(
+ f"Error verifying pickled file from {filename}\n"
+ "-----> !!!! The file is most likely corrupted !!!! <-----\n"
+ "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
+ exc_info=True,
+ )
return None
-
except Exception:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ errors.report(
+ f"Error verifying pickled file from {filename}\n"
+ f"The file may be malicious, so the program is not going to read it.\n"
+ f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
+ exc_info=True,
+ )
return None
return unsafe_torch_load(filename, *args, **kwargs)
@@ -190,4 +194,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
-
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index ecffc206..54824582 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,18 +1,15 @@
-import sys
-import traceback
-from collections import namedtuple
import inspect
+from collections import namedtuple
from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
-from modules import timer
+from modules import errors, timer
def report_exception(c, job):
- print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
@@ -113,6 +110,7 @@ callback_map = dict(
callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
+ callbacks_list_unets=[],
)
@@ -274,6 +272,18 @@ def list_optimizers_callback():
return res
+def list_unets_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_unets']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_unets')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -433,3 +443,10 @@ def on_list_optimizers(callback):
to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)
+
+
+def on_list_unets(callback):
+ """register a function to be called when UI is making a list of alternative options for unet.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
+
+ add_callback(callback_map['callbacks_list_unets'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 57b15862..306a1f35 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,8 +1,8 @@
import os
-import sys
-import traceback
import importlib.util
+from modules import errors
+
def load_module(path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print(f"Error running preload() for {preload_script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index 7ef1a8f8..99bf836a 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,12 +1,11 @@
import os
import re
import sys
-import traceback
from collections import namedtuple
import gradio as gr
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, timer
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
AlwaysVisible = object()
@@ -20,6 +19,9 @@ class Script:
name = None
"""script's internal name derived from title"""
+ section = None
+ """name of UI section that the script's controls will be placed into"""
+
filename = None
args_from = None
args_to = None
@@ -82,6 +84,15 @@ class Script:
pass
+ def before_process(self, p, *args):
+ """
+ This function is called very early before processing begins for AlwaysVisible scripts.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
@@ -264,8 +275,7 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
@@ -281,11 +291,9 @@ def load_scripts():
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
- res = func(*args, **kwargs)
- return res
+ return func(*args, **kwargs)
except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -298,6 +306,7 @@ class ScriptRunner:
self.titles = []
self.infotext_fields = []
self.paste_field_names = []
+ self.inputs = [None]
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -325,69 +334,73 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
- def setup_ui(self):
+ def create_script_ui(self, script):
import modules.api.models as api_models
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+ script.args_from = len(self.inputs)
+ script.args_to = len(self.inputs)
- inputs = [None]
- inputs_alwayson = [True]
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
- def create_script_ui(script, inputs, inputs_alwayson):
- script.args_from = len(inputs)
- script.args_to = len(inputs)
+ if controls is None:
+ return
- controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+ api_args = []
- if controls is None:
- return
+ for control in controls:
+ control.custom_script_source = os.path.basename(script.filename)
- script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
- api_args = []
+ arg_info = api_models.ScriptArg(label=control.label or "")
- for control in controls:
- control.custom_script_source = os.path.basename(script.filename)
+ for field in ("value", "minimum", "maximum", "step", "choices"):
+ v = getattr(control, field, None)
+ if v is not None:
+ setattr(arg_info, field, v)
- arg_info = api_models.ScriptArg(label=control.label or "")
+ api_args.append(arg_info)
- for field in ("value", "minimum", "maximum", "step", "choices"):
- v = getattr(control, field, None)
- if v is not None:
- setattr(arg_info, field, v)
+ script.api_info = api_models.ScriptInfo(
+ name=script.name,
+ is_img2img=script.is_img2img,
+ is_alwayson=script.alwayson,
+ args=api_args,
+ )
- api_args.append(arg_info)
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
- script.api_info = api_models.ScriptInfo(
- name=script.name,
- is_img2img=script.is_img2img,
- is_alwayson=script.alwayson,
- args=api_args,
- )
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
- if script.infotext_fields is not None:
- self.infotext_fields += script.infotext_fields
+ self.inputs += controls
+ script.args_to = len(self.inputs)
- if script.paste_field_names is not None:
- self.paste_field_names += script.paste_field_names
+ def setup_ui_for_section(self, section, scriptlist=None):
+ if scriptlist is None:
+ scriptlist = self.alwayson_scripts
- inputs += controls
- inputs_alwayson += [script.alwayson for _ in controls]
- script.args_to = len(inputs)
+ for script in scriptlist:
+ if script.alwayson and script.section != section:
+ continue
- for script in self.alwayson_scripts:
- with gr.Group() as group:
- create_script_ui(script, inputs, inputs_alwayson)
+ with gr.Group(visible=script.alwayson) as group:
+ self.create_script_ui(script)
script.group = group
- dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- inputs[0] = dropdown
+ def prepare_ui(self):
+ self.inputs = [None]
- for script in self.selectable_scripts:
- with gr.Group(visible=False) as group:
- create_script_ui(script, inputs, inputs_alwayson)
+ def setup_ui(self):
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
- script.group = group
+ self.setup_ui_for_section(None)
+
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
+ self.inputs[0] = dropdown
+
+ self.setup_ui_for_section(None, self.selectable_scripts)
def select_script(script_index):
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
@@ -412,6 +425,7 @@ class ScriptRunner:
)
self.script_load_ctr = 0
+
def onload_script_visibility(params):
title = params.get('Script', None)
if title:
@@ -422,10 +436,10 @@ class ScriptRunner:
else:
return gr.update(visible=False)
- self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
- self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
+ self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
+ self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
- return inputs
+ return self.inputs
def run(self, p, *args):
script_index = args[0]
@@ -445,14 +459,21 @@ class ScriptRunner:
return processed
+ def before_process(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_process: {script.filename}", exc_info=True)
+
def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -460,8 +481,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -469,8 +489,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -478,8 +497,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -487,8 +505,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -496,24 +513,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 08d31080..3b6f95ce 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,11 +43,16 @@ def list_optimizers():
optimizers.extend(new_optimizers)
-def apply_optimizations():
+def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
+ if len(optimizers) == 0:
+ # a script can access the model very early, and optimizations would not be filled by then
+ current_optimizer = None
+ return ''
+
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
@@ -55,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo()
current_optimizer = None
- selection = shared.opts.cross_attention_optimization
+ selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
@@ -63,15 +68,19 @@ def apply_optimizations():
if selection == "None":
matching_optimizer = None
+ elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
+ matching_optimizer = None
elif matching_optimizer is None:
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
- print(f"Applying optimization: {matching_optimizer.name}")
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
+ print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
+ print("Disabling attention optimization")
return ''
@@ -149,6 +158,13 @@ class StableDiffusionModelHijack:
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def apply_optimizations(self, option=None):
+ try:
+ self.optimization_method = apply_optimizations(option)
+ except Exception as e:
+ errors.display(e, "applying cross attention optimization")
+ undo_optimizations()
+
def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
@@ -168,11 +184,7 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
- try:
- self.optimization_method = apply_optimizations()
- except Exception as e:
- errors.display(e, "applying cross attention optimization")
- undo_optimizations()
+ self.apply_optimizations()
self.clip = m.cond_stage_model
@@ -185,6 +197,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
+
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -206,6 +223,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 0eb4c525..b41aa419 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,5 @@
+from __future__ import annotations
import math
-import sys
-import traceback
import psutil
import torch
@@ -58,7 +57,7 @@ class SdOptimizationSdpNoMem(SdOptimization):
name = "sdp-no-mem"
label = "scaled dot product without memory efficient attention"
cmd_opt = "opt_sdp_no_mem_attention"
- priority = 90
+ priority = 80
def is_available(self):
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
@@ -72,7 +71,7 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
name = "sdp"
label = "scaled dot product"
cmd_opt = "opt_sdp_attention"
- priority = 80
+ priority = 70
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
@@ -115,7 +114,7 @@ class SdOptimizationInvokeAI(SdOptimization):
class SdOptimizationDoggettx(SdOptimization):
name = "Doggettx"
cmd_opt = "opt_split_attention"
- priority = 20
+ priority = 90
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
@@ -139,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Cannot import xformers", exc_info=True)
def get_available_vram():
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b1afbaa7..918f6fd6 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -164,6 +164,7 @@ def model_hash(filename):
def select_checkpoint():
+ """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
@@ -171,14 +172,14 @@ def select_checkpoint():
return checkpoint_info
if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
+ error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
+ raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
@@ -313,8 +314,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
timer.record("apply half()")
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
devices.dtype_unet = model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
@@ -423,7 +422,7 @@ class SdModelData:
try:
load_model()
except Exception as e:
- errors.display(e, "loading stable diffusion model")
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
@@ -508,6 +507,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
+ with devices.autocast(), torch.no_grad():
+ sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+
+ timer.record("calculate empty prompt")
+
print(f"Model loaded in {timer.summary()}.")
return sd_model
@@ -527,6 +531,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 552c6c64..f8a0c7ba 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -19,7 +19,8 @@ samplers_k_diffusion = [
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@@ -27,7 +28,8 @@ samplers_k_diffusion = [
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
]
samplers_data_k_diffusion = [
@@ -42,6 +44,14 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+ 'Automatic': None,
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -123,6 +133,16 @@ class CFGDenoiser(torch.nn.Module):
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
+ # TODO add infotext entry
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ elif num_repeats > 0:
+ uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
@@ -228,7 +248,7 @@ class KDiffusionSampler:
self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None
+ self.config = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
@@ -253,6 +273,13 @@ class KDiffusionSampler:
try:
return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
@@ -292,6 +319,31 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif opts.k_sched_type != "Automatic":
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+ sigmas_kwargs = {
+ 'sigma_min': sigma_min,
+ 'sigma_max': sigma_max,
+ }
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+ sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["Schedule rho"] = opts.rho
+
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
@@ -337,13 +389,13 @@ class KDiffusionSampler:
if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigma_sched
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args={
+ extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
@@ -373,7 +425,7 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
new file mode 100644
index 00000000..6d708ad2
--- /dev/null
+++ b/modules/sd_unet.py
@@ -0,0 +1,92 @@
+import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+ new_unets = script_callbacks.list_unets_callback()
+
+ unet_options.clear()
+ unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+ option = option or shared.opts.sd_unet
+
+ if option == "None":
+ return None
+
+ if option == "Automatic":
+ name = shared.sd_model.sd_checkpoint_info.model_name
+
+ options = [x for x in unet_options if x.model_name == name]
+
+ option = options[0].label if options else "None"
+
+ return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+ global current_unet_option
+ global current_unet
+
+ new_option = get_unet_option(option)
+ if new_option == current_unet_option:
+ return
+
+ if current_unet is not None:
+ print(f"Dectivating unet: {current_unet.option.label}")
+ current_unet.deactivate()
+
+ current_unet_option = new_option
+ if current_unet_option is None:
+ current_unet = None
+
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ shared.sd_model.model.diffusion_model.to(devices.device)
+
+ return
+
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
+ devices.torch_gc()
+
+ current_unet = current_unet_option.create_unet()
+ current_unet.option = current_unet_option
+ print(f"Activating unet: {current_unet.option.label}")
+ current_unet.activate()
+
+
+class SdUnetOption:
+ model_name = None
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+ label = None
+ """name of the unet in UI"""
+
+ def create_unet(self):
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+ raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+ def forward(self, x, timesteps, context, *args, **kwargs):
+ raise NotImplementedError()
+
+ def activate(self):
+ pass
+
+ def deactivate(self):
+ pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+
diff --git a/modules/shared.py b/modules/shared.py
index 3099d1d2..7025a754 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -6,6 +6,7 @@ import threading
import time
import gradio as gr
+import torch
import tqdm
import modules.interrogate
@@ -43,19 +44,6 @@ restricted_opts = {
"outdir_init_images"
}
-ui_reorder_categories = [
- "inpaint",
- "sampler",
- "checkboxes",
- "hires_fix",
- "dimensions",
- "cfg",
- "seed",
- "batch",
- "override_settings",
- "scripts",
-]
-
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/glass",
@@ -76,6 +64,9 @@ cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_op
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
+devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
+devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -314,6 +305,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -403,6 +395,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -414,15 +407,16 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
- "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
- "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -484,9 +478,10 @@ options_templates.update(options_section(('ui', "User interface"), {
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
- "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order").needs_restart(),
+ "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
@@ -515,6 +510,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
@@ -630,6 +629,10 @@ class Options:
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+ # 1.4.0 ui_reorder
+ if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
+ self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
+
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 2a8713c8..89792e88 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -29,3 +29,41 @@ def cross_attention_optimizations():
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+def sd_unet_items():
+ import modules.sd_unet
+
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
+
+
+def refresh_unet_list():
+ import modules.sd_unet
+
+ modules.sd_unet.list_unets()
+
+
+ui_reorder_categories_builtin_items = [
+ "inpaint",
+ "sampler",
+ "checkboxes",
+ "hires_fix",
+ "dimensions",
+ "cfg",
+ "seed",
+ "batch",
+ "override_settings",
+]
+
+
+def ui_reorder_categories():
+ from modules import scripts
+
+ yield from ui_reorder_categories_builtin_items
+
+ sections = {}
+ for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
+ if isinstance(script.section, str):
+ sections[script.section] = 1
+
+ yield from sections
+
+ yield "scripts"
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 5858a55f..81cff7bf 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -1,8 +1,10 @@
import base64
import json
+import warnings
+
import numpy as np
import zlib
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw
import torch
@@ -129,14 +131,17 @@ def extract_image_data_embed(image):
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from modules.images import get_font
+ if textfont:
+ warnings.warn(
+ 'passing in a textfont to caption_image_overlay is deprecated and does nothing',
+ DeprecationWarning,
+ stacklevel=2,
+ )
from math import cos
image = srcimage.copy()
fontsize = 32
- if textfont is None:
- from modules.images import get_font
- textfont = get_font(fontsize)
-
factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
for y in range(image.size[1]):
@@ -147,12 +152,12 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
draw = ImageDraw.Draw(image)
- font = ImageFont.truetype(textfont, fontsize)
+ font = get_font(fontsize)
padding = 10
_, _, w, h = draw.textbbox((0, 0), title, font=font)
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
- font = ImageFont.truetype(textfont, fontsize)
+ font = get_font(fontsize)
_, _, w, h = draw.textbbox((0, 0), title, font=font)
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
@@ -163,7 +168,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+ font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d489ed1e..8da050ca 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
from collections import namedtuple
import torch
@@ -14,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -120,16 +118,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ return self.register_embedding_by_name(embedding, model, embedding.name)
+ def register_embedding_by_name(self, embedding, model, name):
+ ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+ if name in self.word_embeddings:
+ # remove old one from the lookup list
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+ else:
+ lookup = self.ids_lookup[first_id]
+ if embedding is not None:
+ lookup += [(ids, embedding)]
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+ if embedding is None:
+ # unregister embedding with specified name
+ if name in self.word_embeddings:
+ del self.word_embeddings[name]
+ if len(self.ids_lookup[first_id])==0:
+ del self.ids_lookup[first_id]
+ return None
+ self.word_embeddings[name] = embedding
return embedding
def get_expected_shape(self):
@@ -207,8 +218,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
@@ -632,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
+ errors.report("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/ui.py b/modules/ui.py
index 5174da63..b7459f08 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -2,20 +2,21 @@ import json
import mimetypes
import os
import sys
-import traceback
from functools import reduce
import warnings
import gradio as gr
-import gradio.routes
import gradio.utils
import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, timer
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
-from modules.paths import script_path, data_path
+from modules.paths import script_path
+from modules.ui_common import create_refresh_button
+from modules.ui_gradio_extensions import reload_javascript
+
from modules.shared import opts, cmd_opts
@@ -35,6 +36,8 @@ import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
import modules.extras
+create_setting_component = ui_settings.create_setting_component
+
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
@@ -231,9 +234,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
+ if gen_info_string:
+ errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -272,12 +274,12 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
button_interrogate = None
button_deepbooru = None
@@ -368,25 +370,6 @@ def apply_setting(key, value):
return getattr(opts, key)
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
def create_output_panel(tabname, outdir):
return ui_common.create_output_panel(tabname, outdir)
@@ -405,22 +388,12 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
- user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
- for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+ for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
-def get_value_for_setting(key):
- value = getattr(opts, key)
-
- info = opts.data_labels[key]
- args = info.component_args() if callable(info.component_args) else info.component_args or {}
- args = {k: v for k, v in args.items() if k not in {'precision'}}
-
- return gr.update(value=value, **args)
-
-
def create_override_settings_dropdown(tabname, row):
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
@@ -456,6 +429,8 @@ def create_ui():
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
+ modules.scripts.scripts_txt2img.prepare_ui()
+
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -463,8 +438,8 @@ def create_ui():
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="txt2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
@@ -505,10 +480,10 @@ def create_ui():
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
- hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.")
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
- hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.")
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -524,6 +499,9 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ else:
+ modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
for component in hr_resolution_preview_inputs:
@@ -616,7 +594,8 @@ def create_ui():
outputs=[
txt2img_prompt,
txt_prompt_img
- ]
+ ],
+ show_progress=False,
)
enable_hr.change(
@@ -779,6 +758,8 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+ modules.scripts.scripts_img2img.prepare_ui()
+
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
@@ -792,8 +773,8 @@ def create_ui():
with gr.Tab(label="Resize to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
@@ -888,6 +869,8 @@ def create_ui():
inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
+ else:
+ modules.scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
@@ -902,7 +885,8 @@ def create_ui():
outputs=[
img2img_prompt,
img2img_prompt_img
- ]
+ ],
+ show_progress=False,
)
img2img_args = dict(
@@ -1183,8 +1167,8 @@ def create_ui():
with gr.Tab(label="Preprocess images", id="preprocess_images"):
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
- process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_process_width")
- process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_process_height")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
@@ -1276,8 +1260,8 @@ def create_ui():
template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
- training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_training_width")
- training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_training_height")
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
@@ -1460,195 +1444,10 @@ def create_ui():
outputs=[],
)
- def create_setting_component(key, is_quicksettings=False):
- def fun():
- return opts.data[key] if key in opts.data else opts.data_labels[key].default
-
- info = opts.data_labels[key]
- t = type(info.default)
-
- args = info.component_args() if callable(info.component_args) else info.component_args
-
- if info.component is not None:
- comp = info.component
- elif t == str:
- comp = gr.Textbox
- elif t == int:
- comp = gr.Number
- elif t == bool:
- comp = gr.Checkbox
- else:
- raise Exception(f'bad options item type: {t} for key {key}')
-
- elem_id = f"setting_{key}"
-
- if info.refresh is not None:
- if is_quicksettings:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
- else:
- with FormRow():
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
- else:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-
- return res
-
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
- components = []
- component_dict = {}
- shared.settings_components = component_dict
-
- script_callbacks.ui_settings_callback()
- opts.reorder()
-
- def run_settings(*args):
- changed = []
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp == dummy_component:
- continue
-
- if opts.set(key, value):
- changed.append(key)
-
- try:
- opts.save(shared.config_filename)
- except RuntimeError:
- return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
- return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
-
- def run_settings_single(value, key):
- if not opts.same_type(value, opts.data_labels[key].default):
- return gr.update(visible=True), opts.dumpjson()
-
- if not opts.set(key, value):
- return gr.update(value=getattr(opts, key)), opts.dumpjson()
-
- opts.save(shared.config_filename)
-
- return get_value_for_setting(key), opts.dumpjson()
-
- with gr.Blocks(analytics_enabled=False) as settings_interface:
- with gr.Row():
- with gr.Column(scale=6):
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- with gr.Column():
- restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
-
- result = gr.HTML(elem_id="settings_result")
-
- quicksettings_names = opts.quicksettings_list
- quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
-
- quicksettings_list = []
-
- previous_section = None
- current_tab = None
- current_row = None
- with gr.Tabs(elem_id="settings"):
- for i, (k, item) in enumerate(opts.data_labels.items()):
- section_must_be_skipped = item.section[0] is None
-
- if previous_section != item.section and not section_must_be_skipped:
- elem_id, text = item.section
-
- if current_tab is not None:
- current_row.__exit__()
- current_tab.__exit__()
-
- gr.Group()
- current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
- current_tab.__enter__()
- current_row = gr.Column(variant='compact')
- current_row.__enter__()
-
- previous_section = item.section
-
- if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
- quicksettings_list.append((i, k, item))
- components.append(dummy_component)
- elif section_must_be_skipped:
- components.append(dummy_component)
- else:
- component = create_setting_component(k)
- component_dict[k] = component
- components.append(component)
-
- if current_tab is not None:
- current_row.__exit__()
- current_tab.__exit__()
-
- with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
- loadsave.create_ui()
-
- with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- with gr.Row():
- unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
- reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
-
- with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
- gr.HTML(shared.html("licenses.html"), elem_id="licenses")
-
- gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
-
-
- def unload_sd_weights():
- modules.sd_models.unload_model_weights()
-
- def reload_sd_weights():
- modules.sd_models.reload_model_weights()
-
- unload_sd_model.click(
- fn=unload_sd_weights,
- inputs=[],
- outputs=[]
- )
-
- reload_sd_model.click(
- fn=reload_sd_weights,
- inputs=[],
- outputs=[]
- )
-
- request_notifications.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='function(){}'
- )
-
- download_localization.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='download_localization'
- )
-
- def reload_scripts():
- modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
-
- reload_script_bodies.click(
- fn=reload_scripts,
- inputs=[],
- outputs=[]
- )
-
- restart_gradio.click(
- fn=shared.state.request_restart,
- _js='restart_reload',
- inputs=[],
- outputs=[],
- )
+ settings = ui_settings.UiSettings()
+ settings.create_ui(loadsave, dummy_component)
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
@@ -1660,7 +1459,7 @@ def create_ui():
]
interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "Settings", "settings")]
+ interfaces += [(settings.interface, "Settings", "settings")]
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
@@ -1670,10 +1469,7 @@ def create_ui():
shared.tab_names.append(label)
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings", variant="compact"):
- for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
- component = create_setting_component(k, is_quicksettings=True)
- component_dict[k] = component
+ settings.add_quicksettings()
parameters_copypaste.connect_paste_params_buttons()
@@ -1704,55 +1500,17 @@ def create_ui():
footer = footer.format(versions=versions_html())
gr.HTML(footer, elem_id="footer")
- text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
- settings_submit.click(
- fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
- inputs=components,
- outputs=[text_settings, result],
- )
-
- for _i, k, _item in quicksettings_list:
- component = component_dict[k]
- info = opts.data_labels[k]
-
- change_handler = component.release if hasattr(component, 'release') else component.change
- change_handler(
- fn=lambda value, k=k: run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, text_settings],
- show_progress=info.refresh is not None,
- )
+ settings.add_functionality(demo)
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
- text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+ settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
- button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
- button_set_checkpoint.click(
- fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
- _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
- inputs=[component_dict['sd_model_checkpoint'], dummy_component],
- outputs=[component_dict['sd_model_checkpoint'], text_settings],
- )
-
- component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
-
- def get_settings_values():
- return [get_value_for_setting(key) for key in component_keys]
-
- demo.load(
- fn=get_settings_values,
- inputs=[],
- outputs=[component_dict[k] for k in component_keys],
- queue=False,
- )
-
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
@@ -1780,7 +1538,7 @@ def create_ui():
primary_model_name,
secondary_model_name,
tertiary_model_name,
- component_dict['sd_model_checkpoint'],
+ settings.component_dict['sd_model_checkpoint'],
modelmerger_result,
]
)
@@ -1794,70 +1552,6 @@ def create_ui():
return demo
-def webpath(fn):
- if fn.startswith(script_path):
- web_path = os.path.relpath(fn, script_path).replace('\\', '/')
- else:
- web_path = os.path.abspath(fn)
-
- return f'file={web_path}?{os.path.getmtime(fn)}'
-
-
-def javascript_html():
- # Ensure localization is in `window` before scripts
- head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
-
- script_js = os.path.join(script_path, "script.js")
- head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
-
- for script in modules.scripts.list_scripts("javascript", ".js"):
- head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
-
- for script in modules.scripts.list_scripts("javascript", ".mjs"):
- head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
-
- if cmd_opts.theme:
- head += f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n'
-
- return head
-
-
-def css_html():
- head = ""
-
- def stylesheet(fn):
- return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
-
- for cssfile in modules.scripts.list_files_with_name("style.css"):
- if not os.path.isfile(cssfile):
- continue
-
- head += stylesheet(cssfile)
-
- if os.path.exists(os.path.join(data_path, "user.css")):
- head += stylesheet(os.path.join(data_path, "user.css"))
-
- return head
-
-
-def reload_javascript():
- js = javascript_html()
- css = css_html()
-
- def template_response(*args, **kwargs):
- res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
- res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
- res.init_headers()
- return res
-
- gradio.routes.templates.TemplateResponse = template_response
-
-
-if not hasattr(shared, 'GradioTemplateResponseOriginal'):
- shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
-
-
def versions_html():
import torch
import launch
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 27ab3ebb..57c2d0ad 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -10,8 +10,11 @@ import subprocess as sp
from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text
import modules.images
+from modules.ui_components import ToolButton
+
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
def update_generation_info(generation_info, html_info, img_index):
@@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index):
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
+ only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
+ only_one = True
images = [images[index]]
start_index = index
@@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index):
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
+ p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
@@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index):
# Make Zip
if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
+ zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
+ namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
+ zip_filepath = os.path.join(path, f"{zip_filename}.zip")
from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
@@ -211,3 +219,23 @@ Requested path was: {f}
))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index ef18f438..3140ed64 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,10 +1,8 @@
import json
import os.path
-import sys
import threading
import time
from datetime import datetime
-import traceback
import git
@@ -13,7 +11,7 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths, config_states
+from modules import extensions, shared, paths, config_states, errors
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -46,8 +44,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print(f"Error getting updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
@@ -113,8 +110,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print(f"Error checking updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
@@ -345,12 +341,12 @@ def install_extension_from_url(dirname, url, branch_name=None):
shutil.rmtree(tmpdir, True)
if not branch_name:
# if no branch is specified, use the default branch
- with git.Repo.clone_from(url, tmpdir) as repo:
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
else:
- with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
@@ -490,8 +486,14 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def preload_extensions_git_metadata():
+ t0 = time.time()
for extension in extensions.extensions:
extension.read_info_from_repo()
+ print(
+ f"preload_extensions_git_metadata for "
+ f"{len(extensions.extensions)} extensions took "
+ f"{time.time() - t0:.2f}s"
+ )
def create_ui():
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
new file mode 100644
index 00000000..b824b113
--- /dev/null
+++ b/modules/ui_gradio_extensions.py
@@ -0,0 +1,69 @@
+import os
+import gradio as gr
+
+from modules import localization, shared, scripts
+from modules.paths import script_path, data_path
+
+
+def webpath(fn):
+ if fn.startswith(script_path):
+ web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+ else:
+ web_path = os.path.abspath(fn)
+
+ return f'file={web_path}?{os.path.getmtime(fn)}'
+
+
+def javascript_html():
+ # Ensure localization is in `window` before scripts
+ head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
+
+ script_js = os.path.join(script_path, "script.js")
+ head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
+
+ for script in scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
+
+ for script in scripts.list_scripts("javascript", ".mjs"):
+ head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
+
+ if shared.cmd_opts.theme:
+ head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
+
+ return head
+
+
+def css_html():
+ head = ""
+
+ def stylesheet(fn):
+ return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
+
+ for cssfile in scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ head += stylesheet(cssfile)
+
+ if os.path.exists(os.path.join(data_path, "user.css")):
+ head += stylesheet(os.path.join(data_path, "user.css"))
+
+ return head
+
+
+def reload_javascript():
+ js = javascript_html()
+ css = css_html()
+
+ def template_response(*args, **kwargs):
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
+ res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
+ res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
+ res.init_headers()
+ return res
+
+ gr.routes.templates.TemplateResponse = template_response
+
+
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
new file mode 100644
index 00000000..7874298e
--- /dev/null
+++ b/modules/ui_settings.py
@@ -0,0 +1,263 @@
+import gradio as gr
+
+from modules import ui_common, shared, script_callbacks, scripts, sd_models
+from modules.call_queue import wrap_gradio_call
+from modules.shared import opts
+from modules.ui_components import FormRow
+from modules.ui_gradio_extensions import reload_javascript
+
+
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
+
+def create_setting_component(key, is_quicksettings=False):
+ def fun():
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+ info = opts.data_labels[key]
+ t = type(info.default)
+
+ args = info.component_args() if callable(info.component_args) else info.component_args
+
+ if info.component is not None:
+ comp = info.component
+ elif t == str:
+ comp = gr.Textbox
+ elif t == int:
+ comp = gr.Number
+ elif t == bool:
+ comp = gr.Checkbox
+ else:
+ raise Exception(f'bad options item type: {t} for key {key}')
+
+ elem_id = f"setting_{key}"
+
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+ else:
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+ else:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+
+ return res
+
+
+class UiSettings:
+ submit = None
+ result = None
+ interface = None
+ components = None
+ component_dict = None
+ dummy_component = None
+ quicksettings_list = None
+ quicksettings_names = None
+ text_settings = None
+
+ def run_settings(self, *args):
+ changed = []
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
+ assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
+ if comp == self.dummy_component:
+ continue
+
+ if opts.set(key, value):
+ changed.append(key)
+
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
+
+ def run_settings_single(self, value, key):
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
+
+ opts.save(shared.config_filename)
+
+ return get_value_for_setting(key), opts.dumpjson()
+
+ def create_ui(self, loadsave, dummy_component):
+ self.components = []
+ self.component_dict = {}
+ self.dummy_component = dummy_component
+
+ shared.settings_components = self.component_dict
+
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
+ with gr.Row():
+ with gr.Column(scale=6):
+ self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
+
+ self.result = gr.HTML(elem_id="settings_result")
+
+ self.quicksettings_names = opts.quicksettings_list
+ self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
+
+ self.quicksettings_list = []
+
+ previous_section = None
+ current_tab = None
+ current_row = None
+ with gr.Tabs(elem_id="settings"):
+ for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
+
+ if previous_section != item.section and not section_must_be_skipped:
+ elem_id, text = item.section
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ gr.Group()
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
+ current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
+
+ previous_section = item.section
+
+ if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
+ self.quicksettings_list.append((i, k, item))
+ self.components.append(dummy_component)
+ elif section_must_be_skipped:
+ self.components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ self.component_dict[k] = component
+ self.components.append(component)
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+ loadsave.create_ui()
+
+ with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ with gr.Row():
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+
+ with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
+
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+ self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+
+ unload_sd_model.click(
+ fn=sd_models.unload_model_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ reload_sd_model.click(
+ fn=sd_models.reload_model_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ request_notifications.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='function(){}'
+ )
+
+ download_localization.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='download_localization'
+ )
+
+ def reload_scripts():
+ scripts.reload_script_body_only()
+ reload_javascript() # need to refresh the html page
+
+ reload_script_bodies.click(
+ fn=reload_scripts,
+ inputs=[],
+ outputs=[]
+ )
+
+ restart_gradio.click(
+ fn=shared.state.request_restart,
+ _js='restart_reload',
+ inputs=[],
+ outputs=[],
+ )
+
+ self.interface = settings_interface
+
+ def add_quicksettings(self):
+ with gr.Row(elem_id="quicksettings", variant="compact"):
+ for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
+ component = create_setting_component(k, is_quicksettings=True)
+ self.component_dict[k] = component
+
+ def add_functionality(self, demo):
+ self.submit.click(
+ fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
+ inputs=self.components,
+ outputs=[self.text_settings, self.result],
+ )
+
+ for _i, k, _item in self.quicksettings_list:
+ component = self.component_dict[k]
+ info = opts.data_labels[k]
+
+ change_handler = component.release if hasattr(component, 'release') else component.change
+ change_handler(
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, self.text_settings],
+ show_progress=info.refresh is not None,
+ )
+
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
+ outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
+ )
+
+ component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
+
+ def get_settings_values():
+ return [get_value_for_setting(key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[self.component_dict[k] for k in component_keys],
+ queue=False,
+ )
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index f05049e1..9fc7d764 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -3,7 +3,7 @@ import tempfile
from collections import namedtuple
from pathlib import Path
-import gradio as gr
+import gradio.components
from PIL import PngImagePlugin
@@ -31,13 +31,16 @@ def check_tmp_file(gradio, filename):
return False
-def save_pil_to_file(pil_image, dir=None):
+def save_pil_to_file(self, pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
+ filename = already_saved_as
- file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
- return file_obj
+ if not shared.opts.save_images_add_number:
+ filename += f'?{os.path.getmtime(already_saved_as)}'
+
+ return filename
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
@@ -51,11 +54,11 @@ def save_pil_to_file(pil_image, dir=None):
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
+ return file_obj.name
# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
+gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed():
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 7b1046d6..3c82861d 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,8 +53,8 @@ class Upscaler:
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = int(img.width * scale)
- dest_h = int(img.height * scale)
+ dest_w = round((img.width * scale - 4) / 8) * 8
+ dest_h = round((img.height * scale - 4) / 8) * 8
for _ in range(3):
shape = (img.width, img.height)