Diffstat (limited to 'modules')
 modules/api/api.py                        |  49
 modules/cache.py                          |   5
 modules/call_queue.py                     |   5
 modules/cmd_args.py                       |   5
 modules/config_states.py                  |  16
 modules/fifo_lock.py                      |  37
 modules/gradio_extensons.py               |  25
 modules/images.py                         |  20
 modules/img2img.py                        |   8
 modules/initialize_util.py                |  19
 modules/interrogate.py                    |   5
 modules/launch_utils.py                   |   4
 modules/lowvram.py                        |  18
 modules/options.py                        |  19
 modules/patches.py                        |  64
 modules/processing.py                     |  57
 modules/processing_scripts/refiner.py     |   4
 modules/processing_scripts/seed.py        |   2
 modules/progress.py                       |  53
 modules/prompt_parser.py                  |   2
 modules/realesrgan_model.py               |   1
 modules/rng.py                            |   2
 modules/script_callbacks.py               |  26
 modules/scripts.py                        |  16
 modules/sd_disable_initialization.py      |  63
 modules/sd_hijack.py                      |  16
 modules/sd_models.py                      |  48
 modules/sd_models_types.py                |  31
 modules/sd_samplers_cfg_denoiser.py       |   4
 modules/sd_samplers_common.py             |  18
 modules/sd_samplers_kdiffusion.py         |  14
 modules/sd_samplers_timesteps.py          |   9
 modules/sd_unet.py                        |   2
 modules/sd_vae.py                         |  13
 modules/shared.py                         |   9
 modules/shared_gradio_themes.py           |   3
 modules/shared_options.py                 |  17
 modules/shared_state.py                   |   2
 modules/ui.py                             |  12
 modules/ui_common.py                      |   2
 modules/ui_components.py                  |  12
 modules/ui_extensions.py                  | 226
 modules/ui_extra_networks_checkpoints.py  |   3
 modules/ui_tempdir.py                     |   2
44 files changed, 703 insertions, 265 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 908c4514..e6edffe7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -4,6 +4,8 @@ import os
import time
import datetime
import uvicorn
+import ipaddress
+import requests
import gradio as gr
from threading import Lock
from io import BytesIO
@@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
-from modules.sd_vae import vae_dict
+from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -56,7 +57,41 @@ def setUpscalers(req: dict):
return reqDict
+def verify_url(url):
+ """Returns True if the url refers to a global resource."""
+
+ import socket
+ from urllib.parse import urlparse
+ try:
+ parsed_url = urlparse(url)
+ domain_name = parsed_url.netloc
+ host = socket.gethostbyname_ex(domain_name)
+ for ip in host[2]:
+ ip_addr = ipaddress.ip_address(ip)
+ if not ip_addr.is_global:
+ return False
+ except Exception:
+ return False
+
+ return True
+
+
def decode_base64_to_image(encoding):
+ if encoding.startswith("http://") or encoding.startswith("https://"):
+ if not opts.api_enable_requests:
+ raise HTTPException(status_code=500, detail="Requests not allowed")
+
+ if opts.api_forbid_local_requests and not verify_url(encoding):
+ raise HTTPException(status_code=500, detail="Request to local resource not allowed")
+
+ headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
+ response = requests.get(encoding, timeout=30, headers=headers)
+ try:
+ image = Image.open(BytesIO(response.content))
+ return image
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Invalid image url") from e
+
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
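
The new SSRF guard above resolves the target hostname and rejects the request unless every resolved address is globally routable; note that `urlparse().netloc` still includes any explicit port, in which case resolution raises and the check fails closed. A minimal standalone sketch of the same idea (the function name is hypothetical, not part of this diff):

```python
import ipaddress
import socket
from urllib.parse import urlparse

def is_global_url(url: str) -> bool:
    """Return True only if every resolved address for the host is globally routable."""
    try:
        host = urlparse(url).netloc
        _, _, addresses = socket.gethostbyname_ex(host)
        # is_global is False for loopback, private, and link-local ranges
        return all(ipaddress.ip_address(ip).is_global for ip in addresses)
    except Exception:
        return False  # unresolvable or malformed input: fail closed

print(is_global_url("http://127.0.0.1/img.png"))  # False (loopback)
```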
@@ -330,6 +365,7 @@ class Api:
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+ p.is_api = True
p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
@@ -390,6 +426,7 @@ class Api:
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.is_api = True
p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
@@ -533,7 +570,7 @@ class Api:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
- shared.opts.set(k, v)
+ shared.opts.set(k, v, is_api=True)
shared.opts.save(shared.config_filename)
return
@@ -565,10 +602,12 @@ class Api:
]
def get_sd_models(self):
- return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ import modules.sd_models as sd_models
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
def get_sd_vaes(self):
- return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+ import modules.sd_vae as sd_vae
+ return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/cache.py b/modules/cache.py
index a7cd3aeb..ff26a213 100644
--- a/modules/cache.py
+++ b/modules/cache.py
@@ -30,9 +30,12 @@ def dump_cache():
time.sleep(1)
with cache_lock:
- with open(cache_filename, "w", encoding="utf8") as file:
+ cache_filename_tmp = cache_filename + "-"
+ with open(cache_filename_tmp, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
+ os.replace(cache_filename_tmp, cache_filename)
+
dump_cache_after = None
dump_cache_thread = None
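
Writing to a sibling temp file and then calling os.replace() makes the cache dump atomic: a crash mid-write can no longer leave a truncated cache file behind, because readers only ever see the old file or the fully written new one. A hedged sketch of the pattern in isolation:

```python
import json
import os

def atomic_json_dump(data: dict, path: str) -> None:
    # Write the full document to a temp file on the same filesystem,
    # then swap it into place; os.replace() is atomic on POSIX and Windows.
    tmp_path = path + "-"
    with open(tmp_path, "w", encoding="utf8") as file:
        json.dump(data, file, indent=4)
    os.replace(tmp_path, path)
```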
diff --git a/modules/call_queue.py b/modules/call_queue.py
index f2eb17d6..ddf0d573 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,11 +1,10 @@
from functools import wraps
import html
-import threading
import time
-from modules import shared, progress, errors, devices
+from modules import shared, progress, errors, devices, fifo_lock
-queue_lock = threading.Lock()
+queue_lock = fifo_lock.FIFOLock()
def wrap_queued_call(func):
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index b0a11538..f0f361bd 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -35,9 +35,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
+parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
-parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
+parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
@@ -81,7 +82,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
-parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
+parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
diff --git a/modules/config_states.py b/modules/config_states.py
index 6f1ab53f..b766aef1 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -8,14 +8,12 @@ import time
import tqdm
from datetime import datetime
-from collections import OrderedDict
import git
from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir
-
-all_config_states = OrderedDict()
+all_config_states = {}
def list_config_states():
@@ -28,10 +26,14 @@ def list_config_states():
for filename in os.listdir(config_states_dir):
if filename.endswith(".json"):
path = os.path.join(config_states_dir, filename)
- with open(path, "r", encoding="utf-8") as f:
- j = json.load(f)
- j["filepath"] = path
- config_states.append(j)
+ try:
+ with open(path, "r", encoding="utf-8") as f:
+ j = json.load(f)
+ assert "created_at" in j, '"created_at" does not exist'
+ j["filepath"] = path
+ config_states.append(j)
+ except Exception as e:
+ print(f'[ERROR]: Config states {path}, {e}')
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
diff --git a/modules/fifo_lock.py b/modules/fifo_lock.py
new file mode 100644
index 00000000..c35b3ae2
--- /dev/null
+++ b/modules/fifo_lock.py
@@ -0,0 +1,37 @@
+import threading
+import collections
+
+
+# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
+class FIFOLock(object):
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._inner_lock = threading.Lock()
+ self._pending_threads = collections.deque()
+
+ def acquire(self, blocking=True):
+ with self._inner_lock:
+ lock_acquired = self._lock.acquire(False)
+ if lock_acquired:
+ return True
+ elif not blocking:
+ return False
+
+ release_event = threading.Event()
+ self._pending_threads.append(release_event)
+
+ release_event.wait()
+ return self._lock.acquire()
+
+ def release(self):
+ with self._inner_lock:
+ if self._pending_threads:
+ release_event = self._pending_threads.popleft()
+ release_event.set()
+
+ self._lock.release()
+
+ __enter__ = acquire
+
+ def __exit__(self, t, v, tb):
+ self.release()
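
Unlike threading.Lock, which makes no fairness guarantee about which waiter wakes up, FIFOLock hands the lock to waiters in arrival order. A small sketch of how it might be exercised (assuming the module path introduced above):

```python
import threading
import time
from modules.fifo_lock import FIFOLock  # path as added by this commit

lock = FIFOLock()

def worker(i: int) -> None:
    with lock:  # __enter__ is acquire(), __exit__ releases
        print(f"worker {i} acquired the lock")
        time.sleep(0.05)

# Start order is staggered so arrival order is deterministic; workers then
# acquire the lock in the same 0..4 order in which they queued up.
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
    time.sleep(0.01)
for t in threads:
    t.join()
```

This is what backs the new queue_lock in call_queue.py below, so queued generation requests are served in the order they arrived.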
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
index 77c34c8b..e6b6835a 100644
--- a/modules/gradio_extensons.py
+++ b/modules/gradio_extensons.py
@@ -1,6 +1,7 @@
import gradio as gr
-from modules import scripts, ui_tempdir
+from modules import scripts, ui_tempdir, patches
+
def add_classes_to_gradio_component(comp):
"""
@@ -40,6 +41,8 @@ def Block_get_config(self):
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip
+ config.pop('example_inputs', None)
+
return config
@@ -51,12 +54,20 @@ def BlockContext_init(self, *args, **kwargs):
return res
-original_IOComponent_init = gr.components.IOComponent.__init__
-original_Block_get_config = gr.blocks.Block.get_config
-original_BlockContext_init = gr.blocks.BlockContext.__init__
+def Blocks_get_config_file(self, *args, **kwargs):
+ config = original_Blocks_get_config_file(self, *args, **kwargs)
+
+ for comp_config in config["components"]:
+ if "example_inputs" in comp_config:
+ comp_config["example_inputs"] = {"serialized": []}
+
+ return config
+
+
+original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
+original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
+original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
+original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
-gr.components.IOComponent.__init__ = IOComponent_init
-gr.blocks.Block.get_config = Block_get_config
-gr.blocks.BlockContext.__init__ = BlockContext_init
ui_tempdir.install_ui_tempdir_override()
diff --git a/modules/images.py b/modules/images.py
index 019c1d60..eb644733 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -355,7 +355,9 @@ class FilenameGenerator:
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
- 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
+ 'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
+ 'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
+ 'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
'prompt': lambda self: sanitize_filename_part(self.prompt),
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
@@ -368,7 +370,8 @@ class FilenameGenerator:
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
- 'none': lambda self: '', # Overrides the default so you can get just the sequence number
+ 'none': lambda self: '', # Overrides the default, so you can get just the sequence number
+ 'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
}
default_time_format = '%Y%m%d%H%M%S'
@@ -448,6 +451,14 @@ class FilenameGenerator:
return sanitize_filename_part(formatted_time, replace_spaces=False)
+ def image_hash(self, *args):
+ length = int(args[0]) if (args and args[0] != "") else None
+ return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
+
+ def string_hash(self, text, *args):
+ length = int(args[0]) if (args and args[0] != "") else 8
+ return hashlib.sha256(text.encode()).hexdigest()[0:length]
+
def apply(self, x):
res = ''
@@ -589,6 +600,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
namegen = FilenameGenerator(p, seed, prompt, image)
+ # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively; switch to PNG, which has a much higher limit
+ if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
+ print('Image dimensions too large; saving as PNG')
+ extension = ".png"
+
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
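
The new [prompt_hash<N>], [negative_prompt_hash<N>], and [full_prompt_hash<N>] filename patterns all funnel into string_hash(), which truncates a SHA-256 hex digest to an optional length (default 8). A quick demonstration of the truncation logic:

```python
import hashlib

def string_hash(text: str, *args) -> str:
    # An optional first pattern argument sets how many hex chars to keep; default is 8.
    length = int(args[0]) if (args and args[0] != "") else 8
    return hashlib.sha256(text.encode()).hexdigest()[0:length]

print(string_hash("masterpiece, best quality"))        # first 8 hex chars of the digest
print(string_hash("masterpiece, best quality", "16"))  # first 16 hex chars
```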
diff --git a/modules/img2img.py b/modules/img2img.py
index ac9fd3f8..1519e132 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -122,15 +122,14 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
is_batch = mode == 5
if mode == 0: # img2img
- image = init_img.convert("RGB")
+ image = init_img
mask = None
elif mode == 1: # img2img sketch
- image = sketch.convert("RGB")
+ image = sketch
mask = None
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
- image = image.convert("RGB")
+ mask = processing.create_binary_mask(mask)
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
orig = inpaint_color_sketch_orig or inpaint_color_sketch
@@ -139,7 +138,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
- image = image.convert("RGB")
elif mode == 4: # inpaint upload mask
image = init_img_inpaint
mask = init_mask_inpaint
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index d8370576..2894eee4 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -132,10 +132,29 @@ def get_gradio_auth_creds():
yield cred
+def dumpstacks():
+ import threading
+ import traceback
+
+ id2name = {th.ident: th.name for th in threading.enumerate()}
+ code = []
+ for threadId, stack in sys._current_frames().items():
+ code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
+ for filename, lineno, name, line in traceback.extract_stack(stack):
+ code.append(f"""File: "{filename}", line {lineno}, in {name}""")
+ if line:
+ code.append(" " + line.strip())
+
+ print("\n".join(code))
+
+
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
+
+ dumpstacks()
+
os._exit(0)
if not os.environ.get("COVERAGE_RUN"):
diff --git a/modules/interrogate.py b/modules/interrogate.py
index a3ae1dd5..3045560d 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -186,9 +186,8 @@ class InterrogateModels:
res = ""
shared.state.begin(job="interrogate")
try:
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- devices.torch_gc()
+ lowvram.send_everything_to_cpu()
+ devices.torch_gc()
self.load()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 449a8755..7e4d5a61 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -246,7 +246,7 @@ def list_extensions(settings_file):
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
- if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions:
+ if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
return []
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
@@ -321,7 +321,7 @@ def prepare_environment():
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
- stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 96f52b7b..45701046 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,5 +1,5 @@
import torch
-from modules import devices
+from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")
@@ -14,6 +14,20 @@ def send_everything_to_cpu():
module_in_gpu = None
+def is_needed(sd_model):
+ return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
+
+
+def apply(sd_model):
+ enable = is_needed(sd_model)
+ shared.parallel_processing_allowed = not enable
+
+ if enable:
+ setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
+ else:
+ sd_model.lowvram = False
+
+
def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
@@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
def is_enabled(sd_model):
- return getattr(sd_model, 'lowvram', False)
+ return sd_model.lowvram
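
Because `and` binds tighter than `or`, is_needed() reads as `lowvram or medvram or (medvram_sdxl and hasattr(sd_model, 'conditioner'))`: the new --medvram-sdxl flag only kicks in for models exposing a `conditioner` attribute, which is how this codebase recognizes SDXL checkpoints. A toy illustration of the precedence (class names hypothetical):

```python
class Sd15Model:
    pass                       # no conditioner attribute

class SdxlModel:
    conditioner = object()     # stand-in for an SDXL conditioner

def is_needed(model, lowvram=False, medvram=False, medvram_sdxl=True):
    return lowvram or medvram or medvram_sdxl and hasattr(model, 'conditioner')

print(is_needed(Sd15Model()))  # False: --medvram-sdxl ignores SD 1.5
print(is_needed(SdxlModel()))  # True: the optimization applies to SDXL only
```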
diff --git a/modules/options.py b/modules/options.py
index db1fb157..758b1ce5 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -8,7 +8,7 @@ from modules.shared_cmd_options import cmd_opts
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
self.default = default
self.label = label
self.component = component
@@ -26,6 +26,9 @@ class OptionInfo:
self.infotext = infotext
+ self.restrict_api = restrict_api
+ """If True, the setting will not be accessible via API"""
+
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
@@ -71,7 +74,7 @@ options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
class Options:
typemap = {int: float}
- def __init__(self, data_labels, restricted_opts):
+ def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()}
self.restricted_opts = restricted_opts
@@ -113,14 +116,18 @@ class Options:
return super(Options, self).__getattribute__(item)
- def set(self, key, value):
+ def set(self, key, value, is_api=False, run_callbacks=True):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
- if self.data_labels[key].do_not_save:
+ option = self.data_labels[key]
+ if option.do_not_save:
+ return False
+
+ if is_api and option.restrict_api:
return False
try:
@@ -128,9 +135,9 @@ class Options:
except RuntimeError:
return False
- if self.data_labels[key].onchange is not None:
+ if run_callbacks and option.onchange is not None:
try:
- self.data_labels[key].onchange()
+ option.onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
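
Together, restrict_api on OptionInfo and the new is_api argument to Options.set() let the settings API refuse writes to sensitive options while the UI keeps full access. A hedged sketch of the behavior (option names chosen for illustration only):

```python
from modules.options import Options, OptionInfo

labels = {
    "outdir_samples": OptionInfo("outputs", "Output directory"),
    "sd_checkpoint_hash": OptionInfo("", "Checkpoint hash", restrict_api=True),
}
opts = Options(labels, restricted_opts=set())

print(opts.set("outdir_samples", "elsewhere", is_api=True))     # True: change applied
print(opts.set("sd_checkpoint_hash", "deadbeef", is_api=True))  # False: blocked for API callers
print(opts.set("sd_checkpoint_hash", "deadbeef"))               # True: UI callers unaffected
```

The API handler above passes is_api=True, and process_images() uses the same path with run_callbacks=False when applying override settings.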
diff --git a/modules/patches.py b/modules/patches.py
new file mode 100644
index 00000000..348235e7
--- /dev/null
+++ b/modules/patches.py
@@ -0,0 +1,64 @@
+from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+ """Replaces a function in a module or a class.
+
+ Also stores the original function in this module, so it can be retrieved via original(key, obj, field).
+ If the function is already replaced by this caller (key), an exception is raised -- use undo() first.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+ replacement: the new function
+
+ Returns:
+ the original function
+ """
+
+ patch_key = (obj, field)
+ if patch_key in originals[key]:
+ raise RuntimeError(f"patch for {field} is already applied")
+
+ original_func = getattr(obj, field)
+ originals[key][patch_key] = original_func
+
+ setattr(obj, field, replacement)
+
+ return original_func
+
+
+def undo(key, obj, field):
+ """Undoes the peplacement by the patch().
+
+ If the function is not replaced, raises an exception.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+
+ Returns:
+ Always None
+ """
+
+ patch_key = (obj, field)
+
+ if patch_key not in originals[key]:
+ raise RuntimeError(f"there is no patch for {field} to undo")
+
+ original_func = originals[key].pop(patch_key)
+ setattr(obj, field, original_func)
+
+ return None
+
+
+def original(key, obj, field):
+ """Returns the original function for the patch created by the patch() function"""
+ patch_key = (obj, field)
+
+ return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)
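
This registry is what gradio_extensons.py above now uses instead of raw monkey-patching; each caller is keyed (typically by __name__), so the same function can only be patched once per caller and can always be undone. A sketch of typical use:

```python
from modules import patches

class Greeter:
    def hello(self):
        return "hello"

def hello_loudly(self):
    # look up the stored original and delegate to it
    original = patches.original(__name__, Greeter, "hello")
    return original(self).upper()

# patch() stores the original under (key, obj, field) and returns it
original_hello = patches.patch(__name__, obj=Greeter, field="hello", replacement=hello_loudly)
print(Greeter().hello())            # HELLO
patches.undo(__name__, Greeter, "hello")
print(Greeter().hello())            # hello
```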
diff --git a/modules/processing.py b/modules/processing.py
index 75f1d66f..066351c1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -81,6 +81,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def create_binary_mask(image):
+ if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.convert('L')
+ return image
def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
@@ -194,6 +200,8 @@ class StableDiffusionProcessing:
sd_vae_name: str = field(default=None, init=False)
sd_vae_hash: str = field(default=None, init=False)
+ is_api: bool = field(default=False, init=False)
+
def __post_init__(self):
if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -258,7 +266,7 @@ class StableDiffusionProcessing:
def setup_scripts(self):
self.scripts_setup_complete = True
- self.scripts.setup_scrips(self)
+ self.scripts.setup_scrips(self, is_ui=not self.is_api)
def comment(self, text):
self.comments[text] = 1
@@ -378,15 +386,20 @@ class StableDiffusionProcessing:
return self.token_merging_ratio or opts.token_merging_ratio
def setup_prompts(self):
- if type(self.prompt) == list:
+ if isinstance(self.prompt, list):
self.all_prompts = self.prompt
+ elif isinstance(self.negative_prompt, list):
+ self.all_prompts = [self.prompt] * len(self.negative_prompt)
else:
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
- if type(self.negative_prompt) == list:
+ if isinstance(self.negative_prompt, list):
self.all_negative_prompts = self.negative_prompt
else:
- self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+ self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
+
+ if len(self.all_prompts) != len(self.all_negative_prompts):
+ raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
@@ -503,10 +516,10 @@ class Processed:
self.s_noise = p.s_noise
self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
- self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
- self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
+ self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
+ self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
+ self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
@@ -693,17 +706,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
- # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
- if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
- sd_models.reload_model_weights()
-
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
+ # and if, after running refiner, the refiner model is not unloaded - webui swaps back to the main model here; if a model override is present, it will be reloaded afterwards
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
for k, v in p.override_settings.items():
- setattr(opts, k, v)
+ opts.set(k, v, is_api=True, run_callbacks=False)
if k == 'sd_model_checkpoint':
sd_models.reload_model_weights()
@@ -732,7 +742,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
- if type(p.prompt) == list:
+ if isinstance(p.prompt, list):
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None
@@ -748,7 +758,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None:
p.tiling = opts.tiling
- if p.refiner_checkpoint not in (None, "", "None"):
+ if p.refiner_checkpoint not in (None, "", "None", "none"):
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
if p.refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
@@ -763,12 +773,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_prompts()
- if type(seed) == list:
+ if isinstance(seed, list):
p.all_seeds = seed
else:
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
- if type(subseed) == list:
+ if isinstance(subseed, list):
p.all_subseeds = subseed
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
@@ -1146,6 +1156,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc()
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
+ if shared.state.interrupted:
+ return samples
+
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
@@ -1259,12 +1272,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_negative_prompt == '':
self.hr_negative_prompt = self.negative_prompt
- if type(self.hr_prompt) == list:
+ if isinstance(self.hr_prompt, list):
self.all_hr_prompts = self.hr_prompt
else:
self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
- if type(self.hr_negative_prompt) == list:
+ if isinstance(self.hr_negative_prompt, list):
self.all_hr_negative_prompts = self.hr_negative_prompt
else:
self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
@@ -1382,7 +1395,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = self.image_mask
if image_mask is not None:
- image_mask = image_mask.convert('L')
+ # image_mask is passed in as RGBA by Gradio to support alpha masks,
+ # but we still want to support binary masks.
+ image_mask = create_binary_mask(image_mask)
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1501,7 +1516,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
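
create_binary_mask() only thresholds the alpha channel when the image is RGBA and its alpha is not uniformly opaque; otherwise it falls back to a plain grayscale conversion, preserving the old binary-mask behavior for uploads without alpha. A standalone sketch of the same logic:

```python
from PIL import Image

def binarize_alpha(image: Image.Image) -> Image.Image:
    # RGBA whose alpha channel is not uniformly opaque:
    # threshold alpha at 128 into a 0/255 mask.
    if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
        return image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
    # Otherwise treat the image itself as the (grayscale) mask.
    return image.convert('L')

semi = Image.new("RGBA", (64, 64), (255, 0, 0, 40))  # alpha 40 everywhere
print(binarize_alpha(semi).getextrema())              # (0, 0): alpha below threshold
```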
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 3c5b37d2..29ccb78f 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -5,7 +5,7 @@ from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
-class ScriptRefiner(scripts.Script):
+class ScriptRefiner(scripts.ScriptBuiltinUI):
section = "accordions"
create_group = False
@@ -42,7 +42,7 @@ class ScriptRefiner(scripts.Script):
# the actual implementation is in sd_samplers_common.py, apply_refiner
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
- p.refiner_checkpoint_info = None
+ p.refiner_checkpoint = None
p.refiner_switch_at = None
else:
p.refiner_checkpoint = refiner_checkpoint
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index 6ce3b2fc..6b6ff987 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -7,7 +7,7 @@ from modules.shared import cmd_opts
from modules.ui_components import ToolButton
-class ScriptSeed(scripts.ScriptBuiltin):
+class ScriptSeed(scripts.ScriptBuiltinUI):
section = "seed"
create_group = False
diff --git a/modules/progress.py b/modules/progress.py
index f405f07f..69921de7 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -48,6 +48,7 @@ def add_task_to_queue(id_job):
class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
+ live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
class ProgressResponse(BaseModel):
@@ -71,7 +72,12 @@ def progressapi(req: ProgressRequest):
completed = req.id_task in finished_tasks
if not active:
- return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+ textinfo = "Waiting..."
+ if queued:
+ sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
+ queue_index = sorted_queued.index(req.id_task)
+ textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
progress = 0
@@ -89,31 +95,30 @@ def progressapi(req: ProgressRequest):
predicted_duration = elapsed_since_start / progress if progress > 0 else None
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+ live_preview = None
id_live_preview = req.id_live_preview
- shared.state.set_current_image()
- if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
- image = shared.state.current_image
- if image is not None:
- buffered = io.BytesIO()
-
- if opts.live_previews_image_format == "png":
- # using optimize for large images takes an enormous amount of time
- if max(*image.size) <= 256:
- save_kwargs = {"optimize": True}
+
+ if opts.live_previews_enable and req.live_preview:
+ shared.state.set_current_image()
+ if shared.state.id_live_preview != req.id_live_preview:
+ image = shared.state.current_image
+ if image is not None:
+ buffered = io.BytesIO()
+
+ if opts.live_previews_image_format == "png":
+ # using optimize for large images takes an enormous amount of time
+ if max(*image.size) <= 256:
+ save_kwargs = {"optimize": True}
+ else:
+ save_kwargs = {"optimize": False, "compress_level": 1}
+
else:
- save_kwargs = {"optimize": False, "compress_level": 1}
-
- else:
- save_kwargs = {}
-
- image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
- base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
- live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
- id_live_preview = shared.state.id_live_preview
- else:
- live_preview = None
- else:
- live_preview = None
+ save_kwargs = {}
+
+ image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
+ base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+ live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
+ id_live_preview = shared.state.id_live_preview
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
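
The queue position shown to waiting clients comes from sorting pending task ids by their enqueue time (add_task_to_queue stores time.time() as the dict value). A small sketch of that computation:

```python
# pending_tasks maps task id -> time.time() when the task was enqueued
pending_tasks = {"task-b": 1002.0, "task-a": 1000.5, "task-c": 1007.3}

def queue_position_text(id_task: str) -> str:
    sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
    queue_index = sorted_queued.index(id_task)
    return "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))

print(queue_position_text("task-b"))  # In queue: 2/3
```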
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index e8c41f38..334efeef 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -107,7 +107,7 @@ def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=N
yield args[(step - 1) % len(args)]
def start(self, args):
def flatten(x):
- if type(x) == str:
+ if isinstance(x, str):
yield x
else:
for gen in x:
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 0700b853..02841c30 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -55,6 +55,7 @@ class UpscalerRealESRGAN(Upscaler):
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
+ device=self.device,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
diff --git a/modules/rng.py b/modules/rng.py
index f927a318..9e8ba2ee 100644
--- a/modules/rng.py
+++ b/modules/rng.py
@@ -98,7 +98,7 @@ def slerp(val, low, high):
class ImageRNG:
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
- self.shape = shape
+ self.shape = tuple(map(int, shape))
self.seeds = seeds
self.subseeds = subseeds
self.subseed_strength = subseed_strength
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 77ee55ee..fab23551 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -28,6 +28,15 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+class ExtraNoiseParams:
+ def __init__(self, noise, x):
+ self.noise = noise
+ """Random noise generated by the seed"""
+
+ self.x = x
+ """Latent image representation of the image"""
+
+
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
@@ -100,6 +109,7 @@ callback_map = dict(
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
+ callbacks_extra_noise=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
@@ -189,6 +199,14 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
+def extra_noise_callback(params: ExtraNoiseParams):
+ for c in callback_map['callbacks_extra_noise']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'callbacks_extra_noise')
+
+
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
try:
@@ -367,6 +385,14 @@ def on_image_saved(callback):
add_callback(callback_map['callbacks_image_saved'], callback)
+def on_extra_noise(callback):
+ """register a function to be called before adding extra noise in img2img or hires fix;
+ The callback is called with one argument:
+ - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
+ """
+ add_callback(callback_map['callbacks_extra_noise'], callback)
+
+
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
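
An extension hooks the new event with on_extra_noise(); since ExtraNoiseParams is mutable, the callback can replace params.noise before the sampler scales it by opts.img2img_extra_noise and adds it to the latent (params.x). A hedged sketch of a registration:

```python
from modules import script_callbacks

def soften_extra_noise(params: script_callbacks.ExtraNoiseParams):
    # Halve the seed-determined noise; the sampler reads params.noise back
    # after the callback runs (see sd_samplers_kdiffusion.py below).
    params.noise = params.noise * 0.5

script_callbacks.on_extra_noise(soften_extra_noise)
```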
diff --git a/modules/scripts.py b/modules/scripts.py
index cbdac2b5..e8518ad0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -68,6 +68,9 @@ class Script:
on_after_component_elem_id = None
"""list of callbacks to be called after a component with an elem_id is created"""
+ setup_for_ui_only = False
+ """If true, the script setup will only be run in Gradio UI, not in API"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -258,7 +261,6 @@ class Script:
self.on_after_component_elem_id.append((elem_id, callback))
-
def describe(self):
"""unused"""
return ""
@@ -267,7 +269,7 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
- tabkind = 'img2img' if self.is_img2img else 'txt2txt'
+ tabkind = 'img2img' if self.is_img2img else 'txt2img'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
@@ -280,13 +282,14 @@ class Script:
pass
-class ScriptBuiltin(Script):
+class ScriptBuiltinUI(Script):
+ setup_for_ui_only = True
def elem_id(self, item_id):
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
- tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+ tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
return f'{tabname}{item_id}'
@@ -728,8 +731,11 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
- def setup_scrips(self, p):
+ def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts:
+ if not is_ui and script.setup_for_ui_only:
+ continue
+
try:
script_args = p.script_args[script.args_from:script.args_to]
script.setup(p, *script_args)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 695c5736..8863107a 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
```
"""
- def __init__(self, state_dict, device):
+ def __init__(self, state_dict, device, weight_dtype_conversion=None):
super().__init__()
self.state_dict = state_dict
self.device = device
+ self.weight_dtype_conversion = weight_dtype_conversion or {}
+ self.default_dtype = self.weight_dtype_conversion.get('')
+
+ def get_weight_dtype(self, key):
+ key_first_term, _ = key.split('.', 1)
+ return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,23 +173,60 @@ class LoadStateDictOnMeta(ReplaceHelper):
sd = self.state_dict
device = self.device
- def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
- params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+ def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
+ used_param_keys = []
- for name, param in params:
- if param.is_meta:
- self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+ for name, param in module._parameters.items():
+ if param is None:
+ continue
- original(self, state_dict, prefix, *args, **kwargs)
+ key = prefix + name
+ sd_param = sd.pop(key, None)
+ if sd_param is not None:
+ state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
+ used_param_keys.append(key)
- for name, _ in params:
+ if param.is_meta:
+ dtype = sd_param.dtype if sd_param is not None else param.dtype
+ module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+
+ for name in module._buffers:
key = prefix + name
- if key in sd:
- del sd[key]
+ sd_param = sd.pop(key, None)
+ if sd_param is not None:
+ state_dict[key] = sd_param
+ used_param_keys.append(key)
+
+ original(module, state_dict, prefix, *args, **kwargs)
+
+ for key in used_param_keys:
+ state_dict.pop(key, None)
+
+ def load_state_dict(original, module, state_dict, strict=True):
+ """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
+ because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
+ all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
+
+ In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
+
+ The dangerous thing about this is that if _load_from_state_dict is not called (because some exotic module
+ overloads the function and does not call the original), the state dict will simply fail to load, since the
+ weights would be on the meta device.
+ """
+
+ if state_dict == sd:
+ state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
+
+ original(module, state_dict, strict=strict)
+
+ module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
+ module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+ layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
+ group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
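
The docstring above leans on PyTorch's "meta" device: a meta tensor carries shape and dtype but no storage, so however many copies torch makes of a meta-only state dict, no real memory is spent. A tiny illustration:

```python
import torch

# Shape and dtype exist, but there is no backing storage to allocate or copy.
weight = torch.zeros(4096, 4096, device="meta", dtype=torch.float16)
print(weight.is_meta, weight.shape, weight.dtype)
# weight.tolist() would raise an error: meta tensors hold no data.
```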
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 46652fbd..592f0055 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -245,7 +245,21 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
- if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
+ embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
+ conditioner.embedders[i] = embedder.wrapped
+ if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
+ embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
+ conditioner.embedders[i] = embedder.wrapped
+
+ if hasattr(m, 'cond_stage_model'):
+ delattr(m, 'cond_stage_model')
+
+ elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6fbdcd6..547e93c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -343,7 +343,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
- if not shared.cmd_opts.no_half:
+ if shared.cmd_opts.no_half:
+ model.float()
+ devices.dtype_unet = torch.float32
+ timer.record("apply float()")
+ else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -359,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
+ devices.dtype_unet = torch.float16
timer.record("apply half()")
- devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -482,8 +486,12 @@ class SdModelData:
return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -510,7 +518,7 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
@@ -518,10 +526,17 @@ def send_model_to_cpu(m):
devices.torch_gc()
-def send_model_to_device(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+def model_target_device(m):
+ if lowvram.is_needed(m):
+ return devices.cpu
else:
+ return devices.device
+
+
+def send_model_to_device(m):
+ lowvram.apply(m)
+
+ if not m.lowvram:
m.to(shared.device)
@@ -579,7 +594,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
@@ -642,13 +665,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
+ model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -660,6 +684,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
@@ -716,7 +744,7 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
new file mode 100644
index 00000000..5ffd2f4f
--- /dev/null
+++ b/modules/sd_models_types.py
@@ -0,0 +1,31 @@
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ from modules.sd_models import CheckpointInfo
+
+
+class WebuiSdModel(LatentDiffusion):
+ """This class is not actually instantinated, but its fields are created and fieeld by webui"""
+
+ lowvram: bool
+ """True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
+
+ sd_model_hash: str
+ """short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
+
+ sd_model_checkpoint: str
+ """path to the file on disk that model weights were obtained from"""
+
+ sd_checkpoint_info: 'CheckpointInfo'
+ """structure with additional information about the file with model's weights"""
+
+ is_sdxl: bool
+ """True if the model's architecture is SDXL"""
+
+ is_sd2: bool
+ """True if the model's architecture is SD 2.x"""
+
+ is_sd1: bool
+ """True if the model's architecture is SD 1.x"""
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index bc9b97e4..b8101d38 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -165,7 +165,7 @@ class CFGDenoiser(torch.nn.Module):
else:
cond_in = catenate_conds([tensor, uncond])
- if shared.batch_cond_uncond:
+ if shared.opts.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
@@ -175,7 +175,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 07fc4434..60fa161c 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -35,22 +35,27 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD":
def samples_to_images_tensor(sample, approximation=None, model=None):
- '''latents -> images [-1, 1]'''
- if approximation is None:
+ """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
+
+ if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
+ from modules import lowvram
+ if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
+ approximation = 1
+
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
- x_sample = sample * 1.5
- x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+ x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else:
if model is None:
model = shared.sd_model
- x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+ with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+ x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
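A hedged usage sketch for the rewritten function, using the approximation indexes defined just above the hunk (Full=0, Approx NN=1, Approx cheap=2, TAESD=3); the helper name is illustrative:

    from modules import sd_samplers_common

    def latent_to_preview_tensor(latent):
        # approximation=2 ("Approx cheap") is a fast linear approximation with no VAE decode
        x = sd_samplers_common.samples_to_images_tensor(latent, approximation=2)
        return (x + 1.0) / 2.0  # map [-1, 1] to [0, 1] for display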
@@ -217,6 +222,7 @@ class Sampler:
self.eta_option_field = 'eta_ancestral'
self.eta_infotext_field = 'Eta'
+ self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key
@@ -273,7 +279,7 @@ class Sampler:
extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
+ if self.eta != self.eta_default:
p.extra_generation_params[self.eta_infotext_field] = self.eta
extra_params_kwargs['eta'] = self.eta
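The new eta_default field lets each sampler family declare its own neutral eta: 1.0 here for the k-diffusion samplers, 0.0 for the timesteps (DDIM-style) samplers later in this diff. Previously the hard-coded comparison against 1.0 meant DDIM's default eta of 0.0 was always written to infotext as if the user had changed it; now "Eta" is recorded only on an actual deviation:

    # each sampler subclass sets self.eta_default in __init__
    if self.eta != self.eta_default:
        p.extra_generation_params[self.eta_infotext_field] = self.eta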
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 67853ff1..b9e0d577 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -3,6 +3,7 @@ import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
+from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
@@ -16,8 +17,8 @@ samplers_k_diffusion = [
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
@@ -34,7 +35,7 @@ samplers_k_diffusion = [
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
- ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
+ ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
]
@@ -145,6 +146,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
xi = x + noise * sigma_sched[0]
+ if opts.img2img_extra_noise > 0:
+ p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
+ extra_noise_params = ExtraNoiseParams(noise, x)
+ extra_noise_callback(extra_noise_params)
+ noise = extra_noise_params.noise
+ xi += noise * opts.img2img_extra_noise
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
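Both this sampler and the timesteps sampler below expose the new extra-noise hook, letting extensions inspect or replace the noise tensor before it is scaled by opts.img2img_extra_noise and added to the latent (the timesteps variant additionally scales the contribution by sqrt_alpha_cumprod). A sketch of how an extension might subscribe, assuming the registration helper follows script_callbacks' usual on_<event> naming:

    from modules import script_callbacks

    def dampen_extra_noise(params: script_callbacks.ExtraNoiseParams):
        # params.noise is the noise tensor and params.x the base latent,
        # matching the ExtraNoiseParams(noise, x) calls above;
        # reassigning params.noise changes what the sampler adds
        params.noise = params.noise * 0.5

    script_callbacks.on_extra_noise(dampen_extra_noise)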
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index c1f534ed..7a6cbd46 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -3,6 +3,7 @@ import inspect
import sys
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
@@ -76,6 +77,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_option_field = 'eta_ddim'
self.eta_infotext_field = 'Eta DDIM'
+ self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
@@ -103,6 +105,13 @@ class CompVisSampler(sd_samplers_common.Sampler):
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
+ if opts.img2img_extra_noise > 0:
+ p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
+ extra_noise_params = ExtraNoiseParams(noise, x)
+ extra_noise_callback(extra_noise_params)
+ noise = extra_noise_params.noise
+ xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 6d708ad2..5525cfbc 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -47,7 +47,7 @@ def apply_unet(option=None):
if current_unet_option is None:
current_unet = None
- if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ if not shared.sd_model.lowvram:
shared.sd_model.model.diffusion_model.to(devices.device)
return
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index fd9a1c2a..669097da 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -70,7 +70,6 @@ def get_filename(filepath):
def refresh_vae_list():
- global vae_dict
vae_dict.clear()
paths = [
@@ -104,7 +103,7 @@ def refresh_vae_list():
name = get_filename(filepath)
vae_dict[name] = filepath
- vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+    sorted_items = sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))
+    vae_dict.clear()
+    vae_dict.update(sorted_items)
def find_vae_near_checkpoint(checkpoint_file):
@@ -160,7 +159,7 @@ def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
- if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
+ if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic):
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
return VaeResolution(resolved=False)
@@ -193,7 +192,7 @@ def load_vae_dict(filename, map_location):
def load_vae(model, vae_file=None, vae_source="from unknown source"):
- global vae_dict, loaded_vae_file
+ global vae_dict, base_vae, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -231,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
restore_base_vae(model)
loaded_vae_file = vae_file
+ model.base_vae = base_vae
+ model.loaded_vae_file = loaded_vae_file
# don't call this from outside
@@ -262,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
if loaded_vae_file == vae_file:
return
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if sd_model.lowvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
@@ -274,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
print("VAE weights loaded.")
diff --git a/modules/shared.py b/modules/shared.py
index d9d01484..63661939 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -2,16 +2,15 @@ import sys
import gradio as gr
-from modules import shared_cmd_options, shared_gradio_themes, options, shared_items
+from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
-from ldm.models.diffusion.ddpm import LatentDiffusion
from modules import util
cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser
-batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
-parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
+batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
+parallel_processing_allowed = True
styles_filename = cmd_opts.styles_file
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
@@ -40,7 +39,7 @@ options_templates = None
opts = None
restricted_opts = None
-sd_model: LatentDiffusion = None
+sd_model: sd_models_types.WebuiSdModel = None
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py
index 485e89d5..822db0a9 100644
--- a/modules/shared_gradio_themes.py
+++ b/modules/shared_gradio_themes.py
@@ -36,7 +36,8 @@ gradio_hf_hub_themes = [
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
- "ysharma/steampunk"
+ "ysharma/steampunk",
+ "NoCrypt/miku"
]
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 69d9d70a..d1389838 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -111,6 +111,12 @@ options_templates.update(options_section(('system', "System"), {
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
+options_templates.update(options_section(('API', "API"), {
+ "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
+ "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
+ "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
+}))
+
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
@@ -166,7 +172,8 @@ For img2img, VAE is used to process user's input image before the sampling, and
options_templates.update(options_section(('img2img', "img2img"), {
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}, infotext='Noise multiplier'),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
+ "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
@@ -185,7 +192,8 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
- "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
+ "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
+ "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -232,6 +240,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
+ "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
@@ -279,13 +288,15 @@ options_templates.update(options_section(('ui', "Live previews"), {
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
+ "live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
+ "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
- "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
+ "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
diff --git a/modules/shared_state.py b/modules/shared_state.py
index 3dc9c788..d272ee5b 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -128,7 +128,7 @@ class State:
devices.torch_gc()
def set_current_image(self):
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
return
diff --git a/modules/ui.py b/modules/ui.py
index a6b1f964..2b6a13cb 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
-from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript
@@ -333,7 +333,7 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
extra_tabs.__enter__()
- with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False):
+ with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
scripts.scripts_txt2img.prepare_ui()
@@ -549,7 +549,7 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__()
- with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow(equal_height=False):
+ with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
copy_image_destinations = {}
@@ -575,7 +575,7 @@ def create_ui():
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
@@ -583,7 +583,7 @@ def create_ui():
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -598,7 +598,7 @@ def create_ui():
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 4c035f2a..eddc4bc8 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -132,7 +132,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
generation_info = None
with gr.Column():
diff --git a/modules/ui_components.py b/modules/ui_components.py
index d08b2b99..55979f62 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -20,6 +20,18 @@ class ToolButton(FormComponent, gr.Button):
return "button"
+class ResizeHandleRow(gr.Row):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.elem_classes.append("resize-handle-row")
+
+ def get_block_name(self):
+ return "row"
+
+
class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 15a8b0bf..e0138267 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -65,7 +65,7 @@ def save_config_state(name):
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
- json.dump(current_config_state, f)
+ json.dump(current_config_state, f, indent=4)
config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys())
@@ -200,119 +200,129 @@ def update_config_states_table(state_name):
created_date = time.asctime(time.gmtime(config_state["created_at"]))
filepath = config_state.get("filepath", "<unknown>")
- code = f"""<!-- {time.time()} -->"""
-
- webui_remote = config_state["webui"]["remote"] or ""
- webui_branch = config_state["webui"]["branch"]
- webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
- webui_commit_date = config_state["webui"]["commit_date"]
- if webui_commit_date:
- webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
- else:
- webui_commit_date = "<unknown>"
-
- remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
- commit_link = make_commit_link(webui_commit_hash, webui_remote)
- date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
-
- current_webui = config_states.get_webui_config()
-
- style_remote = ""
- style_branch = ""
- style_commit = ""
- if current_webui["remote"] != webui_remote:
- style_remote = STYLE_PRIMARY
- if current_webui["branch"] != webui_branch:
- style_branch = STYLE_PRIMARY
- if current_webui["commit_hash"] != webui_commit_hash:
- style_commit = STYLE_PRIMARY
-
- code += f"""<h2>Config Backup: {config_name}</h2>
- <div><b>Filepath:</b> {filepath}</div>
- <div><b>Created at:</b> {created_date}</div>"""
-
- code += f"""<h2>WebUI State</h2>
- <table id="config_state_webui">
- <thead>
- <tr>
- <th>URL</th>
- <th>Branch</th>
- <th>Commit</th>
- <th>Date</th>
- </tr>
- </thead>
- <tbody>
- <tr>
- <td><label{style_remote}>{remote}</label></td>
- <td><label{style_branch}>{webui_branch}</label></td>
- <td><label{style_commit}>{commit_link}</label></td>
- <td><label{style_commit}>{date_link}</label></td>
- </tr>
- </tbody>
- </table>
- """
-
- code += """<h2>Extension State</h2>
- <table id="config_state_extensions">
- <thead>
- <tr>
- <th>Extension</th>
- <th>URL</th>
- <th>Branch</th>
- <th>Commit</th>
- <th>Date</th>
- </tr>
- </thead>
- <tbody>
- """
-
- ext_map = {ext.name: ext for ext in extensions.extensions}
-
- for ext_name, ext_conf in config_state["extensions"].items():
- ext_remote = ext_conf["remote"] or ""
- ext_branch = ext_conf["branch"] or "<unknown>"
- ext_enabled = ext_conf["enabled"]
- ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
- ext_commit_date = ext_conf["commit_date"]
- if ext_commit_date:
- ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ try:
+ webui_remote = config_state["webui"]["remote"] or ""
+ webui_branch = config_state["webui"]["branch"]
+ webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
+ webui_commit_date = config_state["webui"]["commit_date"]
+ if webui_commit_date:
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
else:
- ext_commit_date = "<unknown>"
+ webui_commit_date = "<unknown>"
- remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
- commit_link = make_commit_link(ext_commit_hash, ext_remote)
- date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+ remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
+
+ current_webui = config_states.get_webui_config()
- style_enabled = ""
style_remote = ""
style_branch = ""
style_commit = ""
- if ext_name in ext_map:
- current_ext = ext_map[ext_name]
- current_ext.read_info_from_repo()
- if current_ext.enabled != ext_enabled:
- style_enabled = STYLE_PRIMARY
- if current_ext.remote != ext_remote:
- style_remote = STYLE_PRIMARY
- if current_ext.branch != ext_branch:
- style_branch = STYLE_PRIMARY
- if current_ext.commit_hash != ext_commit_hash:
- style_commit = STYLE_PRIMARY
-
- code += f"""
- <tr>
- <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
- <td><label{style_remote}>{remote}</label></td>
- <td><label{style_branch}>{ext_branch}</label></td>
- <td><label{style_commit}>{commit_link}</label></td>
- <td><label{style_commit}>{date_link}</label></td>
- </tr>
- """
-
- code += """
- </tbody>
- </table>
- """
+ if current_webui["remote"] != webui_remote:
+ style_remote = STYLE_PRIMARY
+ if current_webui["branch"] != webui_branch:
+ style_branch = STYLE_PRIMARY
+ if current_webui["commit_hash"] != webui_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code = f"""<!-- {time.time()} -->
+<h2>Config Backup: {config_name}</h2>
+<div><b>Filepath:</b> {filepath}</div>
+<div><b>Created at:</b> {created_date}</div>
+<h2>WebUI State</h2>
+<table id="config_state_webui">
+ <thead>
+ <tr>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>
+ <label{style_remote}>{remote}</label>
+ </td>
+ <td>
+ <label{style_branch}>{webui_branch}</label>
+ </td>
+ <td>
+ <label{style_commit}>{commit_link}</label>
+ </td>
+ <td>
+ <label{style_commit}>{date_link}</label>
+ </td>
+ </tr>
+ </tbody>
+</table>
+<h2>Extension State</h2>
+<table id="config_state_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+"""
+
+ ext_map = {ext.name: ext for ext in extensions.extensions}
+
+ for ext_name, ext_conf in config_state["extensions"].items():
+ ext_remote = ext_conf["remote"] or ""
+ ext_branch = ext_conf["branch"] or "<unknown>"
+ ext_enabled = ext_conf["enabled"]
+ ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
+ ext_commit_date = ext_conf["commit_date"]
+ if ext_commit_date:
+ ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ else:
+ ext_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
+ commit_link = make_commit_link(ext_commit_hash, ext_remote)
+ date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+
+ style_enabled = ""
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if ext_name in ext_map:
+ current_ext = ext_map[ext_name]
+ current_ext.read_info_from_repo()
+ if current_ext.enabled != ext_enabled:
+ style_enabled = STYLE_PRIMARY
+ if current_ext.remote != ext_remote:
+ style_remote = STYLE_PRIMARY
+ if current_ext.branch != ext_branch:
+ style_branch = STYLE_PRIMARY
+ if current_ext.commit_hash != ext_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f""" <tr>
+ <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{ext_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+"""
+
+ code += """ </tbody>
+</table>"""
+
+ except Exception as e:
+ print(f"[ERROR]: Config states {filepath}, {e}")
+ code = f"""<!-- {time.time()} -->
+<h2>Config Backup: {config_name}</h2>
+<div><b>Filepath:</b> {filepath}</div>
+<div><b>Created at:</b> {created_date}</div>
+<h2>This file is corrupted</h2>"""
return code
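For reference, the shape of the config-state JSON the rewritten function consumes, reconstructed from the keys it reads above; all values are illustrative placeholders:

    config_state = {
        "created_at": 1692000000,             # unix timestamp
        "filepath": "<path to the saved backup json>",
        "webui": {
            "remote": "<git remote url>",
            "branch": "master",
            "commit_hash": "<full hash or None>",
            "commit_date": 1692000000,
        },
        "extensions": {
            "<extension name>": {
                "enabled": True,
                "remote": "<git remote url>",
                "branch": "main",
                "commit_hash": "<full hash or None>",
                "commit_date": 1692000000,
            },
        },
    }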
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index ebb5249f..ca6c2607 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -30,7 +30,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
- for index, name in enumerate(sd_models.checkpoints_list):
+        names = list(sd_models.checkpoints_list)  # take a snapshot of the keys, so the dict can change during iteration without raising
+ for index, name in enumerate(names):
yield self.create_item(name, index)
def allowed_directories_for_previews(self):
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 506017e5..85015db5 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -44,6 +44,8 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
+ else:
+ os.makedirs(dir, exist_ok=True)
use_metadata = False
metadata = PngImagePlugin.PngInfo()