aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py142
-rw-r--r--modules/api/models.py30
-rw-r--r--modules/codeformer_model.py2
-rw-r--r--modules/extensions.py6
-rw-r--r--modules/generation_parameters_copypaste.py20
-rw-r--r--modules/images.py15
-rw-r--r--modules/mac_specific.py12
-rw-r--r--modules/memmon.py12
-rw-r--r--modules/modelloader.py10
-rw-r--r--modules/models/diffusion/uni_pc/__init__.py1
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py100
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py857
-rw-r--r--modules/processing.py26
-rw-r--r--modules/script_callbacks.py8
-rw-r--r--modules/scripts.py56
-rw-r--r--modules/scripts_postprocessing.py2
-rw-r--r--modules/sd_hijack.py12
-rw-r--r--modules/sd_hijack_optimizations.py72
-rw-r--r--modules/sd_hijack_unet.py2
-rw-r--r--modules/sd_models.py58
-rw-r--r--modules/sd_samplers.py2
-rw-r--r--modules/sd_samplers_compvis.py59
-rw-r--r--modules/sd_samplers_kdiffusion.py4
-rw-r--r--modules/sd_vae_approx.py5
-rw-r--r--modules/shared.py39
-rw-r--r--modules/textual_inversion/dataset.py2
-rw-r--r--modules/textual_inversion/textual_inversion.py6
-rw-r--r--modules/timer.py3
-rw-r--r--modules/ui.py85
-rw-r--r--modules/ui_common.py14
-rw-r--r--modules/ui_components.py36
-rw-r--r--modules/ui_extensions.py17
-rw-r--r--modules/ui_extra_networks.py87
-rw-r--r--modules/ui_extra_networks_checkpoints.py14
-rw-r--r--modules/ui_extra_networks_hypernets.py12
-rw-r--r--modules/ui_extra_networks_textual_inversion.py13
36 files changed, 1658 insertions, 183 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 5a9ac5f1..f52f7fef 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -150,6 +150,9 @@ class Api:
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
@@ -163,47 +166,98 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
- def get_script(self, script_name, script_runner):
- if script_name is None:
+ def get_selectable_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
return None, None
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
-
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
+
+ def get_scripts_list(self):
+ t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
+ i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+
+ return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+
+ def get_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
+ return None, None
+
+ script_idx = script_name_to_index(script_name, script_runner.scripts)
+ return script_runner.scripts[script_idx]
+
+ def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
+ #find max idx from the scripts in runner and generate a none array to init script_args
+ last_arg_index = 1
+ for script in script_runner.scripts:
+ if last_arg_index < script.args_to:
+ last_arg_index = script.args_to
+ # None everywhere except position 0 to initialize script args
+ script_args = [None]*last_arg_index
+ # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
+ if selectable_scripts:
+ script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
+ script_args[0] = selectable_idx + 1
+ else:
+ # when [0] = 0 no selectable script to run
+ script_args[0] = 0
+
+ # Now check for always on scripts
+ if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+ for alwayson_script_name in request.alwayson_scripts.keys():
+ alwayson_script = self.get_script(alwayson_script_name, script_runner)
+ if alwayson_script == None:
+ raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
+ # Selectable script in always on script param check
+ if alwayson_script.alwayson == False:
+ raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ # always on script with no arg should always run so you don't really need to add them to the requests
+ if "args" in request.alwayson_scripts[alwayson_script_name]:
+ script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ return script_args
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
- script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+ script_runner = scripts.scripts_txt2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
- populate = txt2imgreq.copy(update={ # Override __init__ params
+ populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True
- }
- )
+ "do_not_save_samples": not txt2imgreq.save_images,
+ "do_not_save_grid": not txt2imgreq.save_images,
+ })
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin()
- if script is not None:
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- b64images = list(map(encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -212,41 +266,53 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
- script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
-
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
- populate = img2imgreq.copy(update={ # Override __init__ params
+ script_runner = scripts.scripts_img2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(True)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True,
- "mask": mask
- }
- )
+ "do_not_save_samples": not img2imgreq.save_images,
+ "do_not_save_grid": not img2imgreq.save_images,
+ "mask": mask,
+ })
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin()
- if script is not None:
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args)
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- b64images = list(map(encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
if not img2imgreq.include_init_images:
img2imgreq.init_images = None
@@ -348,6 +414,16 @@ class Api:
return {}
+ def unloadapi(self):
+ unload_model_weights()
+
+ return {}
+
+ def reloadapi(self):
+ reload_model_weights()
+
+ return {}
+
def skip(self):
shared.state.skip()
diff --git a/modules/api/models.py b/modules/api/models.py
index cba43d3b..4a70f440 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
"outpath_samples",
"outpath_grids",
"sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
+ # "do_not_save_samples",
+ # "do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
@@ -100,13 +100,31 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "init_images", "type": list, "default": None},
+ {"key": "denoising_strength", "type": float, "default": 0.75},
+ {"key": "mask", "type": str, "default": None},
+ {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
).generate_model()
class TextToImageResponse(BaseModel):
@@ -267,3 +285,7 @@ class EmbeddingsResponse(BaseModel):
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+ txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
+ img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 01fb7bd8..8d84bbc9 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -55,7 +55,7 @@ def setup_model(dirname):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
diff --git a/modules/extensions.py b/modules/extensions.py
index 3eef9eaf..ed4b58fe 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -66,7 +66,7 @@ class Extension:
def check_updates(self):
repo = git.Repo(self.path)
- for fetch in repo.remote().fetch("--dry-run"):
+ for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
@@ -79,8 +79,8 @@ class Extension:
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
- repo.git.fetch('--all')
- repo.git.reset('--hard', 'origin')
+ repo.git.fetch(all=True)
+ repo.git.reset('origin', hard=True)
def list_extensions():
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 89dc23bf..6df76858 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -23,13 +23,14 @@ registered_param_bindings = []
class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
+ self.paste_field_names = paste_field_names
def reset():
@@ -134,7 +135,7 @@ def connect_paste_params_buttons():
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
if binding.source_tabname is not None and fields is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
binding.paste_button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
@@ -288,6 +289,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
settings_map = {}
+
+
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
@@ -296,7 +299,11 @@ infotext_to_setting_name_mapping = [
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
- ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC order', 'uni_pc_order'),
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
]
@@ -394,9 +401,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
button.click(
fn=paste_func,
- _js=f"recalculate_prompts_{tabname}",
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
+ button.click(
+ fn=None,
+ _js=f"recalculate_prompts_{tabname}",
+ inputs=[],
+ outputs=[],
+ )
diff --git a/modules/images.py b/modules/images.py
index 38404de3..7030aaaa 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif image_to_save.mode == 'I;16':
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
if opts.enable_pnginfo and info is not None:
exif_bytes = piexif.dump({
@@ -573,6 +573,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
os.replace(temp_file_path, filename_without_extension + extension)
fullfn_without_extension, extension = os.path.splitext(params.filename)
+ if hasattr(os, 'statvfs'):
+ max_name_len = os.statvfs(path).f_namemax
+ fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
+ params.filename = fullfn_without_extension + extension
+ fullfn = params.filename
_atomically_save_image(image, fullfn_without_extension, extension)
image.already_saved_as = fullfn
@@ -582,9 +587,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
ratio = image.width / image.height
if oversize and ratio > 1:
- image = image.resize((opts.target_side_length, image.height * opts.target_side_length // image.width), LANCZOS)
+ image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
elif oversize:
- image = image.resize((image.width * opts.target_side_length // image.height, opts.target_side_length), LANCZOS)
+ image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
@@ -640,6 +645,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
def image_data(data):
+ import gradio as gr
+
try:
image = Image.open(io.BytesIO(data))
textinfo, _ = read_info_from_image(image)
@@ -655,7 +662,7 @@ def image_data(data):
except Exception:
pass
- return '', None
+ return gr.update(), None
def flatten(img, bgcolor):
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..6fe8dea0 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,4 +1,5 @@
import torch
+import platform
from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -23,7 +24,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -32,6 +33,10 @@ if has_mps:
# MPS fix for randn in torchsde
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
+ if platform.mac_ver()[0].startswith("13.2."):
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
+ CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
+
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
@@ -45,9 +50,10 @@ if has_mps:
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-
+ if version.parse(torch.__version__) == version.parse("2.0"):
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
diff --git a/modules/memmon.py b/modules/memmon.py
index a7060f58..4018edcc 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
self.data = defaultdict(int)
try:
- torch.cuda.mem_get_info()
+ self.cuda_mem_get_info()
torch.cuda.memory_stats(self.device)
except Exception as e: # AMD or whatever
print(f"Warning: caught exception '{e}', memory monitor disabled")
self.disabled = True
+ def cuda_mem_get_info(self):
+ index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+ return torch.cuda.mem_get_info(index)
+
def run(self):
if self.disabled:
return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
self.run_flag.clear()
continue
- self.data["min_free"] = torch.cuda.mem_get_info()[0]
+ self.data["min_free"] = self.cuda_mem_get_info()[0]
while self.run_flag.is_set():
- free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
+ free, total = self.cuda_mem_get_info()
self.data["min_free"] = min(self.data["min_free"], free)
time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
- free, total = torch.cuda.mem_get_info()
+ free, total = self.cuda_mem_get_info()
self.data["free"] = free
self.data["total"] = total
diff --git a/modules/modelloader.py b/modules/modelloader.py
index fc3f6249..522affc6 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -4,9 +4,8 @@ import shutil
import importlib
from urllib.parse import urlparse
-from basicsr.utils.download_util import load_file_from_url
from modules import shared
-from modules.upscaler import Upscaler
+from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
+ from basicsr.utils.download_util import load_file_from_url
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
@@ -169,4 +169,8 @@ def load_upscalers():
scaler = cls(commandline_options.get(cmd_name, None))
datas += scaler.scalers
- shared.sd_upscalers = datas
+ shared.sd_upscalers = sorted(
+ datas,
+ # Special case for UpscalerNone keeps it at the beginning of the list.
+ key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
+ )
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
new file mode 100644
index 00000000..e1265e3f
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -0,0 +1 @@
+from .sampler import UniPCSampler
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
new file mode 100644
index 00000000..a241c8a7
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -0,0 +1,100 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
+from modules import shared, devices
+
+
+class UniPCSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.before_sample = None
+ self.after_sample = None
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != devices.device:
+ attr = attr.to(devices.device)
+ setattr(self, name, attr)
+
+ def set_hooks(self, before_sample, after_sample, after_update):
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ # print(f'Data shape for UniPC sampling is {size}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ # SD 1.X is "noise", SD 2.X is "v"
+ model_type = "v" if self.model.parameterization == "v" else "noise"
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=model_type,
+ guidance_type="classifier-free",
+ #condition=conditioning,
+ #unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
+ x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x.to(device), None
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
new file mode 100644
index 00000000..eb5f4e76
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -0,0 +1,857 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm.auto import trange
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0**2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ #condition=None,
+ #unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, None, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input, condition):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous, condition, unconditional_condition):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input, condition)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ if isinstance(condition, dict):
+ assert isinstance(unconditional_condition, dict)
+ c_in = dict()
+ for k in condition:
+ if isinstance(condition[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_condition[k][i],
+ condition[k][i]]) for i in range(len(condition[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_condition[k],
+ condition[k]])
+ elif isinstance(condition, list):
+ c_in = list()
+ assert isinstance(unconditional_condition, list)
+ for i in range(len(condition)):
+ c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
+ else:
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class UniPC:
+ def __init__(
+ self,
+ model_fn,
+ noise_schedule,
+ predict_x0=True,
+ thresholding=False,
+ max_val=1.,
+ variant='bh1',
+ condition=None,
+ unconditional_condition=None,
+ before_sample=None,
+ after_sample=None,
+ after_update=None
+ ):
+ """Construct a UniPC.
+
+ We support both data_prediction and noise_prediction.
+ """
+ self.model_fn_ = model_fn
+ self.noise_schedule = noise_schedule
+ self.variant = variant
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+ self.condition = condition
+ self.unconditional_condition = unconditional_condition
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
+
+ def dynamic_thresholding_fn(self, x0, t=None):
+ """
+ The dynamic thresholding method.
+ """
+ dims = x0.dim()
+ p = self.dynamic_thresholding_ratio
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model(self, x, t):
+ cond = self.condition
+ uncond = self.unconditional_condition
+ if self.before_sample is not None:
+ x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
+ res = self.model_fn_(x, t, cond, uncond)
+ if self.after_sample is not None:
+ x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
+
+ if isinstance(res, tuple):
+ # (None, pred_x0)
+ res = res[1]
+
+ return res
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3,] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3,] * (K - 1) + [1]
+ else:
+ orders = [3,] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2,] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2,] * (K - 1) + [1]
+ elif order == 1:
+ K = steps
+ orders = [1,] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
+ if len(t.shape) == 0:
+ t = t.view(-1)
+ if 'bh' in self.variant:
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
+ else:
+ assert self.variant == 'vary_coeff'
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
+
+ def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
+ #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
+ ns = self.noise_schedule
+ assert order <= len(model_prev_list)
+
+ # first compute rks
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+ lambda_t = ns.marginal_lambda(t)
+ model_prev_0 = model_prev_list[-1]
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
+ rk = (lambda_prev_i - lambda_prev_0) / h
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.)
+ rks = torch.tensor(rks, device=x.device)
+
+ K = len(rks)
+ # build C matrix
+ C = []
+
+ col = torch.ones_like(rks)
+ for k in range(1, K + 1):
+ C.append(col)
+ col = col * rks / (k + 1)
+ C = torch.stack(C, dim=1)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
+ A_p = C_inv_p
+
+ if use_corrector:
+ #print('using corrector')
+ C_inv = torch.linalg.inv(C)
+ A_c = C_inv
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh)
+ h_phi_ks = []
+ factorial_k = 1
+ h_phi_k = h_phi_1
+ for k in range(1, K + 2):
+ h_phi_ks.append(h_phi_k)
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
+ factorial_k *= (k + 1)
+
+ model_t = None
+ if self.predict_x0:
+ x_t_ = (
+ sigma_t / sigma_prev_0 * x
+ - alpha_t * h_phi_1 * model_prev_0
+ )
+ # now predictor
+ x_t = x_t_
+ if len(D1s) > 0:
+ # compute the residuals for predictor
+ for k in range(K - 1):
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
+ # now corrector
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_
+ k = 0
+ for k in range(K - 1):
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+ else:
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ x_t_ = (
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
+ - (sigma_t * h_phi_1) * model_prev_0
+ )
+ # now predictor
+ x_t = x_t_
+ if len(D1s) > 0:
+ # compute the residuals for predictor
+ for k in range(K - 1):
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
+ # now corrector
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_
+ k = 0
+ for k in range(K - 1):
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+ return x_t, model_t
+
+ def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
+ #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
+ ns = self.noise_schedule
+ assert order <= len(model_prev_list)
+ dims = x.dim()
+
+ # first compute rks
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+ lambda_t = ns.marginal_lambda(t)
+ model_prev_0 = model_prev_list[-1]
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.)
+ rks = torch.tensor(rks, device=x.device)
+
+ R = []
+ b = []
+
+ hh = -h[0] if self.predict_x0 else h[0]
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.variant == 'bh1':
+ B_h = hh
+ elif self.variant == 'bh2':
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= (i + 1)
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=x.device)
+
+ # now predictor
+ use_predictor = len(D1s) > 0 and x_t is None
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ if x_t is None:
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ else:
+ D1s = None
+
+ if use_corrector:
+ #print('using corrector')
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_c = torch.linalg.solve(R, b)
+
+ model_t = None
+ if self.predict_x0:
+ x_t_ = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
+ )
+
+ if x_t is None:
+ if use_predictor:
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
+
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ if D1s is not None:
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
+ )
+ if x_t is None:
+ if use_predictor:
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
+
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ if D1s is not None:
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+ return x_t, model_t
+
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05, corrector=False,
+ ):
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'multistep':
+ assert steps >= order, "UniPC order must be < sampling steps"
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+ for step in trange(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ #print('this step order:', step_order)
+ if step == steps:
+ #print('do not run corrector at the last step')
+ use_corrector = False
+ else:
+ use_corrector = True
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ model_prev_list[-1] = model_x
+ else:
+ raise NotImplementedError()
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,)*(dims - 1)]
diff --git a/modules/processing.py b/modules/processing.py
index 2009d3bf..2e5a363f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
+ extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -597,6 +598,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ if p.scripts is not None:
+ p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
if len(prompts) == 0:
break
@@ -685,6 +689,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
+ if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+ image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
+
+ if opts.save_mask:
+ images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+
+ if opts.save_mask_composite:
+ images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+
+ if opts.return_mask:
+ output_images.append(image_mask)
+
+ if opts.return_mask_composite:
+ output_images.append(image_mask_composite)
+
del x_samples_ddim
devices.torch_gc()
@@ -709,7 +729,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks:
+ if not p.disable_extra_networks and extra_network_data:
extra_networks.deactivate(p, extra_network_data)
devices.torch_gc()
@@ -888,7 +908,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
+ img2img_sampler_name = self.sampler_name
+ if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+ img2img_sampler_name = 'DDIM'
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index edd0e2a7..07911876 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -29,7 +29,7 @@ class ImageSaveParams:
class CFGDenoiserParams:
- def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -44,6 +44,12 @@ class CFGDenoiserParams:
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
+
+ self.text_cond = text_cond
+ """ Encoder hidden states of text conditioning from prompt"""
+
+ self.text_uncond = text_uncond
+ """ Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams:
diff --git a/modules/scripts.py b/modules/scripts.py
index 24056a12..d661be4f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -33,6 +33,11 @@ class Script:
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
"""
+ paste_field_names = None
+ """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
+ various "Send to <X>" buttons when clicked
+ """
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -80,6 +85,20 @@ class Script:
pass
+ def before_process_batch(self, p, *args, **kwargs):
+ """
+ Called before extra networks are parsed from the prompt, so you can add
+ new extra network keywords to the prompt with this callback.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -220,7 +239,15 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
- for scriptfile in sorted(scripts_list):
+ def orderby(basedir):
+ # 1st webui, 2nd extensions-builtin, 3rd extensions
+ priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
+ for key in priority:
+ if basedir.startswith(key):
+ return priority[key]
+ return 9999
+
+ for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
@@ -256,6 +283,7 @@ class ScriptRunner:
self.alwayson_scripts = []
self.titles = []
self.infotext_fields = []
+ self.paste_field_names = []
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -304,6 +332,9 @@ class ScriptRunner:
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
+
inputs += controls
inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
@@ -388,6 +419,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def before_process_batch(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process_batch(p, *script_args, **kwargs)
+ except Exception:
+ print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
@@ -481,6 +521,18 @@ def reload_scripts():
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+def add_classes_to_gradio_component(comp):
+ """
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+ """
+
+ comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
+
+ if getattr(comp, 'multiselect', False):
+ comp.elem_classes.append('multiselect')
+
+
+
def IOComponent_init(self, *args, **kwargs):
if scripts_current is not None:
scripts_current.before_component(self, **kwargs)
@@ -489,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
res = original_IOComponent_init(self, *args, **kwargs)
+ add_classes_to_gradio_component(self)
+
script_callbacks.after_component_callback(self, **kwargs)
if scripts_current is not None:
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index ce0ebb61..b11568c0 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
inputs = []
for script in self.scripts_in_preferred_order():
- with gr.Box() as group:
+ with gr.Row() as group:
self.create_script_ui(script, inputs)
script.group = group
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 79476783..f4bb0266 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -37,11 +37,23 @@ def apply_optimizations():
optimization_method = None
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
+ elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
+ optimization_method = 'sdp-no-mem'
+ elif cmd_opts.opt_sdp_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
+ optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index c02d954c..372555ff 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
dtype = q.dtype
if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
+ q, k, v = q.float(), k.float(), v.float()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
+# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
+# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, inner_dim = x.shape
+
+ if mask is not None:
+ mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
+ mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
+
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ head_dim = inner_dim // h
+ q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+
+ del q_in, k_in, v_in
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k, v = q.float(), k.float(), v.float()
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
+ hidden_states = hidden_states.to(dtype)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return scaled_dot_product_attention_forward(self, x, context, mask)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+def sdp_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+ out = out.to(dtype)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
+
+def sdp_no_mem_attnblock_forward(self, x):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return sdp_attnblock_forward(self, x)
+
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 843ab66c..15858263 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -67,7 +67,7 @@ def hijack_ddpm_edit():
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
-if version.parse(torch.__version__) <= version.parse("1.13.1"):
+if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 127e9663..86218c08 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -105,9 +105,15 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
checkpoint_alisases.clear()
- model_list = modelloader.load_models(model_path=model_path, model_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
cmd_ckpt = shared.cmd_opts.ckpt
+ if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
+ model_url = None
+ else:
+ model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
+
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
+
if os.path.exists(cmd_ckpt):
checkpoint_info = CheckpointInfo(cmd_ckpt)
checkpoint_info.register()
@@ -172,7 +178,7 @@ def select_checkpoint():
return checkpoint_info
-chckpoint_dict_replacements = {
+checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
@@ -180,7 +186,7 @@ chckpoint_dict_replacements = {
def transform_checkpoint_dict_key(k):
- for text, replacement in chckpoint_dict_replacements.items():
+ for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
@@ -204,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+def read_metadata_from_safetensors(filename):
+ import json
+
+ with open(filename, mode="rb") as file:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+
+ assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+
+ res = {}
+ for k, v in json_obj.get("__metadata__", {}).items():
+ res[k] = v
+ if isinstance(v, str) and v[0:1] == '{':
+ try:
+ res[k] = json.loads(v)
+ except Exception as e:
+ pass
+
+ return res
+
+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
@@ -464,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
checkpoints_loaded.clear()
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model
try:
@@ -487,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
return sd_model
+
+def unload_model_weights(sd_model=None, info=None):
+ from modules import lowvram, devices, sd_hijack
+ timer = Timer()
+
+ if shared.sd_model:
+
+ # shared.sd_model.cond_stage_model.to(devices.cpu)
+ # shared.sd_model.first_stage_model.to(devices.cpu)
+ shared.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ sd_model = None
+ gc.collect()
+ devices.torch_gc()
+ torch.cuda.empty_cache()
+
+ print(f"Unloaded weights {timer.summary()}.")
+
+ return sd_model \ No newline at end of file
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 28c2136f..ff361f22 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -32,7 +32,7 @@ def set_samplers():
global samplers, samplers_for_img2img
hidden = set(shared.opts.hide_samplers)
- hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
+ hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index d03131cd..083da18c 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -7,19 +7,27 @@ import torch
from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
+import modules.models.diffusion.uni_pc
samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+ sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
+ self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
+ self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
+ self.orig_p_sample_ddim = None
+ if self.is_plms:
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms
+ elif self.is_ddim:
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler:
return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
+
+ return res
+
+ def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler:
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x_dec = img_orig * self.mask + self.nmask * x_dec
+ x = img_orig * self.mask + self.nmask * x
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
@@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+ return x, ts, cond, unconditional_conditioning
+ def update_step(self, last_latent):
if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+ self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else:
- self.last_latent = res[1]
+ self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent)
@@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step
shared.total_tqdm.update()
- return res
+ def after_sample(self, x, ts, cond, uncond, res):
+ if not self.is_unipc:
+ self.update_step(res[1])
+
+ return x, ts, cond, uncond, res
+
+ def unipc_after_update(self, x, model_x):
+ self.update_step(x)
def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
+ if self.is_unipc:
+ keys = [
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC order', 'uni_pc_order'),
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ]
+
+ for name, key in keys:
+ v = getattr(shared.opts, key)
+ if v != shared.opts.get_default(key):
+ p.extra_generation_params[name] = v
+
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+ if self.is_unipc:
+ self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
+ if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
+ num_steps = shared.opts.uni_pc_order
valid_step = 999 / (1000 // num_steps)
if valid_step == math.floor(valid_step):
return int(valid_step) + 1
-
+
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 528f513f..93f0e55a 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
if tensor.shape[1] == uncond.shape[1]:
if not is_edit_model:
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0027343a..e2f00468 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -35,8 +35,11 @@ def model():
global sd_vae_approx_model
if sd_vae_approx_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
+ sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
diff --git a/modules/shared.py b/modules/shared.py
index 4edcb5ef..73ce77d4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -69,6 +69,8 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
@@ -81,6 +83,7 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
@@ -104,15 +107,20 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
-parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
+parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
+parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
-cmd_opts = parser.parse_args()
+if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
+ cmd_opts = parser.parse_args()
+else:
+ cmd_opts, _ = parser.parse_known_args()
restricted_opts = {
"samples_filename_pattern",
@@ -303,6 +311,7 @@ def list_samplers():
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
+tab_names = []
options_templates = {}
@@ -324,10 +333,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
+ "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
+ "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+ "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
+ "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
@@ -438,13 +451,16 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (em)"),
- "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (em)"),
+ "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
+ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
+ "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"),
+ "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+ "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
@@ -460,6 +476,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
@@ -485,6 +502,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
+ 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
+ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
+ 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
+ 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
}))
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
@@ -559,6 +580,15 @@ class Options:
return True
+ def get_default(self, key):
+ """returns the default value for the key"""
+
+ data_label = self.data_labels.get(key)
+ if data_label is None:
+ return None
+
+ return data_label.default
+
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
@@ -691,6 +721,7 @@ class TotalTQDM:
def clear(self):
if self._tqdm is not None:
+ self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 1568b2b8..af9fbcf2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -115,7 +115,7 @@ class PersonalizedBase(Dataset):
weight /= weight.mean()
elif use_weight:
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
- weight = torch.ones([channels] + latent_size)
+ weight = torch.ones(latent_sample.shape)
else:
weight = None
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c63c7d1d..d2e62e58 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -152,7 +152,11 @@ class EmbeddingDatabase:
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
+ if data:
+ name = data.get('name', name)
+ else:
+ # if data is None, means this is not an embeding, just a preview image
+ return
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
diff --git a/modules/timer.py b/modules/timer.py
index 57a4f17a..ba92be33 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -33,3 +33,6 @@ class Timer:
res += ")"
return res
+
+ def reset(self):
+ self.__init__()
diff --git a/modules/ui.py b/modules/ui.py
index 0516c643..af8546c2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
-from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
+from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -89,7 +89,7 @@ paste_symbol = '\u2199\ufe0f' # ↙
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
-clear_prompt_symbol = '\U0001F5D1' # 🗑️
+clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
@@ -179,14 +179,13 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
+ with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+ random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
@@ -195,8 +194,8 @@ def create_seed_inputs(target_interface):
seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
with FormRow(visible=False) as seed_extra_row_2:
@@ -291,19 +290,19 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
button_interrogate = None
button_deepbooru = None
if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
+ with gr.Column(scale=1, elem_classes="interrogate-col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box"):
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+ interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
@@ -325,9 +324,9 @@ def create_toprow(is_img2img):
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
+ negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
clear_prompt_button.click(
@@ -479,7 +478,9 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+ with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
@@ -492,7 +493,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@@ -586,7 +587,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
txt_prompt_img.change(
fn=modules.images.image_data,
@@ -757,7 +758,9 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
@@ -774,7 +777,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
- with FormRow(elem_id="img2img_checkboxes", variant="compact"):
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
@@ -904,7 +907,7 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
@@ -939,7 +942,7 @@ def create_ui():
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
@@ -1491,11 +1494,33 @@ def create_ui():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ with gr.Row():
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.TabItem("Licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+
+ def unload_sd_weights():
+ modules.sd_models.unload_model_weights()
+
+ def reload_sd_weights():
+ modules.sd_models.reload_model_weights()
+
+ unload_sd_model.click(
+ fn=unload_sd_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ reload_sd_model.click(
+ fn=reload_sd_weights,
+ inputs=[],
+ outputs=[]
+ )
request_notifications.click(
fn=lambda: None,
@@ -1563,6 +1588,10 @@ def create_ui():
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
+ shared.tab_names = []
+ for _interface, label, _ifid in interfaces:
+ shared.tab_names.append(label)
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
@@ -1573,6 +1602,8 @@ def create_ui():
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
+ if label in shared.opts.hidden_tabs:
+ continue
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
@@ -1592,11 +1623,13 @@ def create_ui():
for i, k, item in quicksettings_list:
component = component_dict[k]
+ info = opts.data_labels[k]
component.change(
fn=lambda value, k=k: run_settings_single(value, key=k),
inputs=[component],
outputs=[component, text_settings],
+ show_progress=info.refresh is not None,
)
text_settings.change(
@@ -1745,7 +1778,8 @@ def create_ui():
def reload_javascript():
- head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
+ script_js = os.path.join(script_path, "script.js")
+ head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
@@ -1754,6 +1788,9 @@ def reload_javascript():
for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
+ for script in modules.scripts.list_scripts("javascript", ".mjs"):
+ head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
+
head += f'<script type="text/javascript">{inline}</script>\n'
def template_response(*args, **kwargs):
diff --git a/modules/ui_common.py b/modules/ui_common.py
index fd047f31..7b752b45 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -129,8 +129,8 @@ Requested path was: {f}
generation_info = None
with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+ with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
+ open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}')
@@ -160,6 +160,7 @@ Requested path was: {f}
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info],
+ show_progress=False,
)
save.click(
@@ -198,9 +199,16 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ paste_field_names = []
+ if tabname == "txt2img":
+ paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
+ elif tabname == "img2img":
+ paste_field_names = modules.scripts.scripts_img2img.paste_field_names
+
for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
+ paste_field_names=paste_field_names
))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 284ca0cf..2b1da2cb 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -1,55 +1,61 @@
import gradio as gr
-class ToolButton(gr.Button, gr.components.FormComponent):
- """Small button with single emoji as text, fits inside gradio forms"""
+class FormComponent:
+ def get_expected_parent(self):
+ return gr.components.Form
- def __init__(self, **kwargs):
- super().__init__(variant="tool", **kwargs)
- def get_block_name(self):
- return "button"
+gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
-class ToolButtonTop(gr.Button, gr.components.FormComponent):
- """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+class ToolButton(FormComponent, gr.Button):
+ """Small button with single emoji as text, fits inside gradio forms"""
- def __init__(self, **kwargs):
- super().__init__(variant="tool-top", **kwargs)
+ def __init__(self, *args, **kwargs):
+ classes = kwargs.pop("elem_classes", [])
+ super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
def get_block_name(self):
return "button"
-class FormRow(gr.Row, gr.components.FormComponent):
+class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
-class FormGroup(gr.Group, gr.components.FormComponent):
+class FormColumn(FormComponent, gr.Column):
+ """Same as gr.Column but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "column"
+
+
+class FormGroup(FormComponent, gr.Group):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "group"
-class FormHTML(gr.HTML, gr.components.FormComponent):
+class FormHTML(FormComponent, gr.HTML):
"""Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self):
return "html"
-class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
+class FormColorPicker(FormComponent, gr.ColorPicker):
"""Same as gr.ColorPicker but fits inside gradio forms"""
def get_block_name(self):
return "colorpicker"
-class DropdownMulti(gr.Dropdown):
+class DropdownMulti(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index bd4308ef..4a502974 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,6 +1,5 @@
import json
import os.path
-import shutil
import sys
import time
import traceback
@@ -141,22 +140,20 @@ def install_extension_from_url(dirname, url):
try:
shutil.rmtree(tmpdir, True)
-
- repo = git.Repo.clone_from(url, tmpdir)
- repo.remote().fetch()
-
+ with git.Repo.clone_from(url, tmpdir) as repo:
+ repo.remote().fetch()
+ for submodule in repo.submodules:
+ submodule.update()
try:
os.rename(tmpdir, target_dir)
except OSError as err:
- # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
- # Shouldn't cause any new issues at least but we probably want to handle it there too.
if err.errno == errno.EXDEV:
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
# Since we can't use a rename, do the slower but more versitile shutil.move()
shutil.move(tmpdir, target_dir)
else:
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
- raise(err)
+ raise err
import launch
launch.run_extension_installer(target_dir)
@@ -244,7 +241,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column):
hidden += 1
continue
- install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+ install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
@@ -304,7 +301,7 @@ def create_ui():
with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 0c7ba173..08a69930 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -22,21 +22,37 @@ def register_page(page):
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
-def add_pages_to_demo(app):
- def fetch_file(filename: str = ""):
- from starlette.responses import FileResponse
+def fetch_file(filename: str = ""):
+ from starlette.responses import FileResponse
+
+ if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+
+ ext = os.path.splitext(filename)[1].lower()
+ if ext not in (".png", ".jpg", ".webp"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+
+ # would profit from returning 304
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
- raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+def get_metadata(page: str = "", item: str = ""):
+ from starlette.responses import JSONResponse
- ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
+ if page is None:
+ return JSONResponse({})
- # would profit from returning 304
- return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+ metadata = page.metadata.get(item)
+ if metadata is None:
+ return JSONResponse({})
+ return JSONResponse({"metadata": metadata})
+
+
+def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
class ExtraNetworksPage:
@@ -45,6 +61,7 @@ class ExtraNetworksPage:
self.name = title.lower()
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
+ self.metadata = {}
def refresh(self):
pass
@@ -66,6 +83,8 @@ class ExtraNetworksPage:
view = shared.opts.extra_networks_default_view
items_html = ''
+ self.metadata = {}
+
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
@@ -86,12 +105,16 @@ class ExtraNetworksPage:
subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f"""
-<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
+<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")}
</button>
""" for subdir in subdirs])
for item in self.list_items():
+ metadata = item.get("metadata")
+ if metadata:
+ self.metadata[item["name"]] = metadata
+
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
@@ -124,9 +147,13 @@ class ExtraNetworksPage:
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
- height = f"height: {shared.opts.extra_networks_card_height}em;" if shared.opts.extra_networks_card_height else ''
- width = f"width: {shared.opts.extra_networks_card_width}em;" if shared.opts.extra_networks_card_width else ''
+ height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
+ width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
+ metadata_button = ""
+ metadata = item.get("metadata")
+ if metadata:
+ metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata({json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
args = {
"style": f"'{height}{width}{background_image}'",
@@ -134,13 +161,44 @@ class ExtraNetworksPage:
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
+ "description": (item.get("description") or ""),
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
+ "metadata_button": metadata_button,
}
return self.card_page.format(**args)
+ def find_preview(self, path):
+ """
+ Find a preview PNG for a given path (without extension) and call link_preview on it.
+ """
+
+ preview_extensions = ["png", "jpg", "webp"]
+ if shared.opts.samples_format not in preview_extensions:
+ preview_extensions.append(shared.opts.samples_format)
+
+ potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
+
+ for file in potential_files:
+ if os.path.isfile(file):
+ return self.link_preview(file)
+
+ return None
+
+ def find_description(self, path):
+ """
+ Find and read a description file for a given path (without extension).
+ """
+ for file in [f"{path}.txt", f"{path}.description.txt"]:
+ try:
+ with open(file, "r", encoding="utf-8", errors="replace") as f:
+ return f.read()
+ except OSError:
+ pass
+ return None
+
def intialize():
extra_pages.clear()
@@ -182,12 +240,12 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages:
with gr.Tab(page.title):
+
page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem)
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
- button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
@@ -198,7 +256,6 @@ def create_ui(container, button, tabname):
state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
- button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh():
res = []
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 04097a79..a17aa9c9 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -1,7 +1,6 @@
import html
import json
import os
-import urllib.parse
from modules import shared, ui_extra_networks, sd_models
@@ -17,21 +16,14 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items():
path, ext = os.path.splitext(checkpoint.filename)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
-
yield {
"name": checkpoint.name_for_extra,
"filename": path,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "local_preview": path + ".png",
+ "local_preview": f"{path}.{shared.opts.samples_format}",
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 57851088..6187e000 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -14,21 +14,15 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
for name, path in shared.hypernetworks.items():
path, ext = os.path.splitext(path)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
yield {
"name": name,
"filename": path,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(path),
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": path + ".png",
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index bb64eb81..6944d559 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,7 +1,7 @@
import json
import os
-from modules import ui_extra_networks, sd_hijack
+from modules import ui_extra_networks, sd_hijack, shared
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
@@ -15,19 +15,14 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
path, ext = os.path.splitext(embedding.filename)
- preview_file = path + ".preview.png"
-
- preview = None
- if os.path.isfile(preview_file):
- preview = self.link_preview(preview_file)
-
yield {
"name": embedding.name,
"filename": embedding.filename,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(embedding.filename),
"prompt": json.dumps(embedding.name),
- "local_preview": path + ".preview.png",
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
}
def allowed_directories_for_previews(self):