aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py2
-rw-r--r--modules/api/models.py3
-rw-r--r--modules/call_queue.py19
-rw-r--r--modules/devices.py31
-rw-r--r--modules/errors.py2
-rw-r--r--modules/extras.py150
-rw-r--r--modules/generation_parameters_copypaste.py6
-rw-r--r--modules/hashes.py85
-rw-r--r--modules/hypernetworks/hypernetwork.py73
-rw-r--r--modules/images.py5
-rw-r--r--modules/img2img.py4
-rw-r--r--modules/processing.py25
-rw-r--r--modules/progress.py99
-rw-r--r--modules/prompt_parser.py7
-rw-r--r--modules/realesrgan_model.py12
-rw-r--r--modules/script_callbacks.py20
-rw-r--r--modules/sd_hijack.py8
-rw-r--r--modules/sd_hijack_checkpoint.py38
-rw-r--r--modules/sd_hijack_clip.py17
-rw-r--r--modules/sd_models.py121
-rw-r--r--modules/sd_samplers.py19
-rw-r--r--modules/sd_vae.py212
-rw-r--r--modules/sd_vae_approx.py2
-rw-r--r--modules/shared.py53
-rw-r--r--modules/styles.py12
-rw-r--r--modules/textual_inversion/dataset.py52
-rw-r--r--modules/textual_inversion/image_embedding.py4
-rw-r--r--modules/textual_inversion/logging.py2
-rw-r--r--modules/textual_inversion/preprocess.py38
-rw-r--r--modules/textual_inversion/textual_inversion.py61
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py346
-rw-r--r--modules/ui_progress.py101
-rw-r--r--modules/upscaler.py1
34 files changed, 1095 insertions, 539 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 5767ba90..9814bbc2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -371,7 +371,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/api/models.py b/modules/api/models.py
index c78095ca..1eb1fcf1 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -224,7 +224,8 @@ class UpscalerItem(BaseModel):
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
- hash: str = Field(title="Hash")
+ hash: Optional[str] = Field(title="Short hash")
+ sha256: Optional[str] = Field(title="sha256 hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 4cd49533..92097c15 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -4,7 +4,7 @@ import threading
import traceback
import time
-from modules import shared
+from modules import shared, progress
queue_lock = threading.Lock()
@@ -22,12 +22,23 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
- shared.state.begin()
+ # if the first argument is a string that says "task(...)", it is treated as a job id
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ id_task = args[0]
+ progress.add_task_to_queue(id_task)
+ else:
+ id_task = None
with queue_lock:
- res = func(*args, **kwargs)
+ shared.state.begin()
+ progress.start_task(id_task)
+
+ try:
+ res = func(*args, **kwargs)
+ finally:
+ progress.finish_task(id_task)
- shared.state.end()
+ shared.state.end()
return res
diff --git a/modules/devices.py b/modules/devices.py
index ac3ae0c9..524ec7af 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -106,6 +106,36 @@ def autocast(disable=False):
return torch.autocast("cuda")
+class NansException(Exception):
+ pass
+
+
+def test_for_nans(x, where):
+ from modules import shared
+
+ if shared.cmd_opts.disable_nan_check:
+ return
+
+ if not torch.all(torch.isnan(x)).item():
+ return
+
+ if where == "unet":
+ message = "A tensor with all NaNs was produced in Unet."
+
+ if not shared.cmd_opts.no_half:
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+
+ elif where == "vae":
+ message = "A tensor with all NaNs was produced in VAE."
+
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
+ else:
+ message = "A tensor with all NaNs was produced."
+
+ raise NansException(message)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
@@ -159,3 +189,4 @@ if has_mps():
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
+
diff --git a/modules/errors.py b/modules/errors.py
index a668c014..a10e8708 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -19,7 +19,7 @@ def display(e: Exception, task):
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
-The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
diff --git a/modules/extras.py b/modules/extras.py
index a03d558e..d03f976e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
-from modules import processing, shared, images, devices, sd_models, sd_samplers
+from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -251,7 +251,8 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- return sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models.find_checkpoint_config(x) if x else None
+ return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@@ -274,10 +275,25 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+
+ return tensor
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [*[gr.update() for _ in range(4)], message]
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -287,51 +303,96 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- primary_model_info = sd_models.checkpoints_list[primary_model_name]
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
- result_is_inpainting_model = False
+ def filename_weighed_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_differnece():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (None, weighted_sum),
- "Add difference": (get_difference, add_difference),
+ "Weighted sum": (filename_weighed_sum, None, weighted_sum),
+ "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
}
- theta_func1, theta_func2 = theta_funcs[interp_method]
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
- if theta_func1 and not tertiary_model_info:
- shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
- shared.state.end()
- return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ if not primary_model_name:
+ return fail("Failed: Merging requires a primary model.")
+
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+
+ if theta_func2 and not secondary_model_name:
+ return fail("Failed: Merging requires a secondary model.")
+
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
+
+ if theta_func1 and not tertiary_model_name:
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
- shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
+
+ result_is_inpainting_model = False
+
+ if theta_func2:
+ shared.state.textinfo = f"Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
if theta_func1:
+ shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.textinfo = 'Merging B and C'
+ shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
+
+ shared.state.sampling_step += 1
del theta_2
+ shared.state.nextjob()
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
-
+ shared.state.textinfo = 'Merging A and B'
+ shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
- if 'model' in key and key in theta_1:
+ if theta_1 and 'model' in key and key in theta_1:
+
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
a = theta_0[key]
b = theta_1[key]
- shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -346,32 +407,39 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
else:
theta_0[key] = theta_func2(a, b, multiplier)
- if save_as_half:
- theta_0[key] = theta_0[key].half()
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ shared.state.sampling_step += 1
- # I believe this part should be discarded, but I'll leave it for now until I am sure
- for key in theta_1.keys():
- if 'model' in key and key not in theta_0:
- theta_0[key] = theta_1[key]
- if save_as_half:
- theta_0[key] = theta_0[key].half()
del theta_1
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
- filename = \
- primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
- secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
- interp_method.replace(" ", "_") + \
- '-merged.' + \
- ("inpainting." if result_is_inpainting_model else "") + \
- checkpoint_format
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
+
+ del vae_dict
+
+ if save_as_half and not theta_func2:
+ for key in theta_0.keys():
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
- shared.state.textinfo = f"Saving to {output_modelname}..."
+ shared.state.nextjob()
+ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -384,8 +452,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
- print("Checkpoint saved.")
- shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
shared.state.end()
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 620aa606..a381ff59 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -7,7 +7,7 @@ from pathlib import Path
import gradio as gr
from modules.shared import script_path
-from modules import shared, ui_tempdir
+from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image
@@ -37,6 +37,9 @@ def quote(text):
def image_from_url_text(filedata):
+ if filedata is None:
+ return None
+
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
@@ -298,6 +301,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
prompt = file.read()
params = parse_generation_parameters(prompt)
+ script_callbacks.infotext_pasted_callback(prompt, params)
res = []
for output, key in paste_fields:
diff --git a/modules/hashes.py b/modules/hashes.py
new file mode 100644
index 00000000..b85a7580
--- /dev/null
+++ b/modules/hashes.py
@@ -0,0 +1,85 @@
+import hashlib
+import json
+import os.path
+
+import filelock
+
+
+cache_filename = "cache.json"
+cache_data = None
+
+
+def dump_cache():
+ with filelock.FileLock(cache_filename+".lock"):
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ global cache_data
+
+ if cache_data is None:
+ with filelock.FileLock(cache_filename+".lock"):
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def calculate_sha256(filename):
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def sha256_from_cache(filename, title):
+ hashes = cache("hashes")
+ ondisk_mtime = os.path.getmtime(filename)
+
+ if title not in hashes:
+ return None
+
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+
+ return cached_sha256
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
+
+ sha256_value = sha256_from_cache(filename, title)
+ if sha256_value is not None:
+ return sha256_value
+
+ print(f"Calculating sha256 for {filename}: ", end='')
+ sha256_value = calculate_sha256(filename)
+ print(f"{sha256_value}")
+
+ hashes[title] = {
+ "mtime": os.path.getmtime(filename),
+ "sha256": sha256_value,
+ }
+
+ dump_cache()
+
+ return sha256_value
+
+
+
+
+
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 194679e8..74e78582 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -24,7 +24,6 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
-
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -226,7 +225,7 @@ class Hypernetwork:
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['hash'] = self.shorthash()
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
@@ -238,32 +237,33 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- print(self.layer_structure)
- optional_info = state_dict.get('optional_info', None)
- if optional_info is not None:
- print(f"INFO:\n {optional_info}\n")
- self.optional_info = optional_info
+ self.optional_info = state_dict.get('optional_info', None)
self.activation_func = state_dict.get('activation_func', None)
- print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
- print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
- print(f"Layer norm is set to {self.add_layer_norm}")
self.dropout_structure = state_dict.get('dropout_structure', None)
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
- print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
if self.dropout_structure is None:
- print("Using previous dropout structure")
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- print(f"Dropout structure is set to {self.dropout_structure}")
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ if shared.opts.print_hypernet_extra:
+ if self.optional_info is not None:
+ print(f" INFO:\n {self.optional_info}\n")
+
+ print(f" Layer structure: {self.layer_structure}")
+ print(f" Activation function: {self.activation_func}")
+ print(f" Weight initialization: {self.weight_init}")
+ print(f" Layer norm: {self.add_layer_norm}")
+ print(f" Dropout usage: {self.use_dropout}" )
+ print(f" Activate last layer: {self.activate_output}")
+ print(f" Dropout structure: {self.dropout_structure}")
- if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
@@ -290,6 +290,11 @@ class Hypernetwork:
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
self.eval()
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+
+ return sha256[0:10]
+
def list_hypernetworks(path):
res = {}
@@ -297,7 +302,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name + f"({sd_models.model_hash(filename)})"] = filename
+ res[name] = filename
return res
@@ -448,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -498,6 +503,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
+
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -507,7 +515,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
@@ -553,6 +561,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
@@ -566,6 +575,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -602,7 +613,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+ loss_logging.append(_loss_step)
if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate)
@@ -621,7 +632,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -632,6 +642,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
+
+ if shared.opts.training_enable_tensorboard:
+ epoch_num = hypernetwork.step // len(ds)
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
+ mean_loss = sum(loss_logging) / len(loss_logging)
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
@@ -682,7 +700,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
torch.cuda.set_rng_state_all(cuda_rng_state)
hypernetwork.train()
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
+ f"Validation at epoch {epoch_num}", image,
+ hypernetwork.step)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
@@ -704,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
hypernetwork.eval()
#report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
@@ -724,7 +749,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
- hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)
diff --git a/modules/images.py b/modules/images.py
index c3a5fc8b..3b1c5f34 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -605,8 +605,9 @@ def read_info_from_image(image):
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
- items['exif comment'] = exif_comment
- geninfo = exif_comment
+ if exif_comment:
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
diff --git a/modules/img2img.py b/modules/img2img.py
index f62783c6..2168c8e2 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
@@ -101,7 +101,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
diff --git a/modules/processing.py b/modules/processing.py
index f04a0e1e..a3e9f709 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return image_conditioning
-class StableDiffusionProcessing():
+class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
@@ -102,7 +102,6 @@ class StableDiffusionProcessing():
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
- self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
@@ -156,6 +155,10 @@ class StableDiffusionProcessing():
self.all_subseeds = None
self.iteration = 0
+ @property
+ def sd_model(self):
+ return shared.sd_model
+
def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@@ -236,7 +239,6 @@ class StableDiffusionProcessing():
raise NotImplementedError()
def close(self):
- self.sd_model = None
self.sampler = None
@@ -437,7 +439,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
@@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_model_checkpoint':
sd_models.reload_model_weights() # make onchange call for changing SD model
- p.sd_model = shared.sd_model
if k == 'sd_vae':
sd_vae.reload_vae_weights() # make onchange call for changing VAE
@@ -531,10 +532,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -571,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
if state.job_count == -1:
state.job_count = p.n_iter
@@ -608,6 +609,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+ for x in x_samples_ddim:
+ devices.test_for_nans(x, "vae")
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -853,7 +857,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
diff --git a/modules/progress.py b/modules/progress.py
new file mode 100644
index 00000000..c69ecf3d
--- /dev/null
+++ b/modules/progress.py
@@ -0,0 +1,99 @@
+import base64
+import io
+import time
+
+import gradio as gr
+from pydantic import BaseModel, Field
+
+from modules.shared import opts
+
+import modules.shared as shared
+
+
+current_task = None
+pending_tasks = {}
+finished_tasks = []
+
+
+def start_task(id_task):
+ global current_task
+
+ current_task = id_task
+ pending_tasks.pop(id_task, None)
+
+
+def finish_task(id_task):
+ global current_task
+
+ if current_task == id_task:
+ current_task = None
+
+ finished_tasks.append(id_task)
+ if len(finished_tasks) > 16:
+ finished_tasks.pop(0)
+
+
+def add_task_to_queue(id_job):
+ pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+ id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
+
+
+class ProgressResponse(BaseModel):
+ active: bool = Field(title="Whether the task is being worked on right now")
+ queued: bool = Field(title="Whether the task is in queue")
+ completed: bool = Field(title="Whether the task has already finished")
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+ eta: float = Field(default=None, title="ETA in secs")
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+ id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def setup_progress_api(app):
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+
+
+def progressapi(req: ProgressRequest):
+ active = req.id_task == current_task
+ queued = req.id_task in pending_tasks
+ completed = req.id_task in finished_tasks
+
+ if not active:
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+
+ progress = 0
+
+ job_count, job_no = shared.state.job_count, shared.state.job_no
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
+
+ if job_count > 0:
+ progress += job_no / job_count
+ if sampling_steps > 0 and job_count > 0:
+ progress += 1 / job_count * sampling_step / sampling_steps
+
+ progress = min(progress, 1)
+
+ elapsed_since_start = time.time() - shared.state.time_start
+ predicted_duration = elapsed_since_start / progress if progress > 0 else None
+ eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+
+ id_live_preview = req.id_live_preview
+ shared.state.set_current_image()
+ if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+ image = shared.state.current_image
+ if image is not None:
+ buffered = io.BytesIO()
+ image.save(buffered, format="png")
+ live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+ id_live_preview = shared.state.id_live_preview
+ else:
+ live_preview = None
+ else:
+ live_preview = None
+
+ return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 870218db..69665372 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -274,6 +274,7 @@ re_attention = re.compile(r"""
:
""", re.X)
+re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
"""
@@ -339,7 +340,11 @@ def parse_prompt_attention(text):
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
- res.append([text, 1.0])
+ parts = re.split(re_break, text)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ res.append([part, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 3ac0b97a..47f70251 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler):
return img
info = self.load_model(path)
- if not os.path.exists(info.data_path):
+ if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
- model_path=info.data_path,
+ model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
@@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler):
def load_model(self, path):
try:
- info = None
- for scaler in self.scalers:
- if scaler.data_path == path:
- info = scaler
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
if info is None:
print(f"Unable to find model info: {path}")
return None
- model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- info.data_path = model_file
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 608c5300..a9e19236 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -2,7 +2,7 @@ import sys
import traceback
from collections import namedtuple
import inspect
-from typing import Optional
+from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
@@ -71,6 +71,7 @@ callback_map = dict(
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
+ callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
)
@@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')
+def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+ for c in callback_map['callbacks_infotext_pasted']:
+ try:
+ c.callback(infotext, params)
+ except Exception:
+ report_exception(c, 'infotext_pasted')
+
+
def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']):
try:
@@ -290,6 +299,15 @@ def on_image_grid(callback):
add_callback(callback_map['callbacks_image_grid'], callback)
+def on_infotext_pasted(callback):
+ """register a function to be called before applying an infotext.
+ The callback is called with two arguments:
+ - infotext: str - raw infotext.
+ - result: Dict[str, any] - parsed infotext parameters.
+ """
+ add_callback(callback_map['callbacks_infotext_pasted'], callback)
+
+
def on_script_unloaded(callback):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here"""
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b0d95af..870eba88 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -69,12 +69,6 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-def fix_checkpoint():
- ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
-
-
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -106,8 +100,6 @@ class StableDiffusionModelHijack:
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
-
- fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
index 5712972f..2604d969 100644
--- a/modules/sd_hijack_checkpoint.py
+++ b/modules/sd_hijack_checkpoint.py
@@ -1,10 +1,46 @@
from torch.utils.checkpoint import checkpoint
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.openaimodel
+
+
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
+
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
+
def ResBlock_forward(self, x, emb):
- return checkpoint(self._forward, x, emb) \ No newline at end of file
+ return checkpoint(self._forward, x, emb)
+
+
+stored = []
+
+
+def add():
+ if len(stored) != 0:
+ return
+
+ stored.extend([
+ ldm.modules.attention.BasicTransformerBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
+ ])
+
+ ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
+
+
+def remove():
+ if len(stored) == 0:
+ return
+
+ ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
+
+ stored.clear()
+
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 852afc66..9fa5c5c5 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
token_count = 0
last_comma = -1
- def next_chunk():
- """puts current chunk into the list of results and produces the next one - empty"""
+ def next_chunk(is_last=False):
+ """puts current chunk into the list of results and produces the next one - empty;
+ if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count"""
nonlocal token_count
nonlocal last_comma
nonlocal chunk
- token_count += len(chunk.tokens)
+ if is_last:
+ token_count += len(chunk.tokens)
+ else:
+ token_count += self.chunk_length
+
to_add = self.chunk_length - len(chunk.tokens)
if to_add > 0:
chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk = PromptChunk()
for tokens, (text, weight) in zip(tokenized, parsed):
+ if text == 'BREAK' and weight == -1:
+ next_chunk()
+ continue
+
position = 0
while position < len(tokens):
token = tokens[position]
@@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0:
- next_chunk()
+ next_chunk(is_last=True)
return chunks, token_count
diff --git a/modules/sd_models.py b/modules/sd_models.py
index c466f273..6a681cef 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,17 +14,58 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
+checkpoint_alisases = {}
checkpoints_loaded = collections.OrderedDict()
+
+class CheckpointInfo:
+ def __init__(self, filename):
+ self.filename = filename
+ abspath = os.path.abspath(filename)
+
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
+ else:
+ name = os.path.basename(filename)
+
+ if name.startswith("\\") or name.startswith("/"):
+ name = name[1:]
+
+ self.title = name
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+ self.hash = model_hash(filename)
+
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for id in self.ids:
+ checkpoint_alisases[id] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10]
+
+ if self.shorthash not in self.ids:
+ self.ids += [self.shorthash, self.sha256]
+ self.register()
+
+ return self.shorthash
+
+
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -43,10 +84,14 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- convert = lambda name: int(name) if name.isdigit() else name.lower()
- alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def checkpoint_tiles():
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
@@ -62,48 +107,38 @@ def find_checkpoint_config(info):
def list_models():
checkpoints_list.clear()
+ checkpoint_alisases.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
- def modeltitle(path, shorthash):
- abspath = os.path.abspath(path)
-
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
- elif abspath.startswith(model_path):
- name = abspath.replace(model_path, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
- return f'{name} [{shorthash}]', shortname
-
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
- h = model_hash(cmd_ckpt)
- title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
- shared.opts.data['sd_model_checkpoint'] = title
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
+ checkpoint_info.register()
+
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
for filename in model_list:
- h = model_hash(filename)
- title, short_model_name = modeltitle(filename, h)
+ checkpoint_info = CheckpointInfo(filename)
+ checkpoint_info.register()
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+def get_closet_checkpoint_match(search_string):
+ checkpoint_info = checkpoint_alisases.get(search_string, None)
+ if checkpoint_info is not None:
+ return checkpoint_info
+
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
-def get_closet_checkpoint_match(searchString):
- applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
- if len(applicable) > 0:
- return applicable[0]
return None
def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+
try:
with open(filename, "rb") as file:
import hashlib
@@ -119,7 +154,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoints_list.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -189,9 +224,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info, vae_file="auto"):
- checkpoint_file = checkpoint_info.filename
- sd_model_hash = checkpoint_info.hash
+def load_model_weights(model, checkpoint_info: CheckpointInfo):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +235,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
- sd = read_state_dict(checkpoint_file)
+ sd = read_state_dict(checkpoint_info.filename)
model.load_state_dict(sd, strict=False)
del sd
@@ -235,15 +269,16 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
- model.sd_model_checkpoint = checkpoint_file
+ model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
- sd_vae.load_vae(model, vae_file)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ sd_vae.load_vae(model, vae_file, vae_source)
def enable_midas_autodownload():
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 01221b89..6261d1f7 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -138,9 +138,9 @@ def samples_to_image_grid(samples, approximation=None):
def store_latent(decoded):
state.current_latent = decoded
- if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
- shared.state.current_image = sample_to_image(decoded)
+ shared.state.assign_current_image(sample_to_image(decoded))
class InterruptedException(BaseException):
@@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler:
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
@@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler:
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-
+
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -352,6 +351,13 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ devices.test_for_nans(x_out, "unet")
+
+ if opts.live_preview_content == "Prompt":
+ store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ store_latent(x_out[-uncond.shape[0]:])
+
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
@@ -423,7 +429,8 @@ class KDiffusionSampler:
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
- store_latent(latent)
+ if opts.live_preview_content == "Combined":
+ store_latent(latent)
self.last_latent = latent
if self.stop_at is not None and step > self.stop_at:
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0a49daa1..4ce238b8 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -9,23 +9,9 @@ import glob
from copy import deepcopy
-model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(models_path, model_dir))
-vae_dir = "VAE"
-vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
-
-
+vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-
-default_vae_dict = {"auto": "auto", "None": None, None: None}
-default_vae_list = ["auto", "None"]
-
-
-default_vae_values = [default_vae_dict[x] for x in default_vae_list]
-vae_dict = dict(default_vae_dict)
-vae_list = list(default_vae_list)
-first_load = True
+vae_dict = {}
base_vae = None
@@ -64,100 +50,84 @@ def restore_base_vae(model):
def get_filename(filepath):
- return os.path.splitext(os.path.basename(filepath))[0]
-
-
-def refresh_vae_list(vae_path=vae_path, model_path=model_path):
- global vae_dict, vae_list
- res = {}
- candidates = [
- *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
- *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
- *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
+ return os.path.basename(filepath)
+
+
+def refresh_vae_list():
+ vae_dict.clear()
+
+ paths = [
+ os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
+ os.path.join(sd_models.model_path, '**/*.vae.pt'),
+ os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
+ os.path.join(vae_path, '**/*.ckpt'),
+ os.path.join(vae_path, '**/*.pt'),
+ os.path.join(vae_path, '**/*.safetensors'),
]
- if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
- candidates.append(shared.cmd_opts.vae_path)
+
+ if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
+ ]
+
+ if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
+ ]
+
+ candidates = []
+ for path in paths:
+ candidates += glob.iglob(path, recursive=True)
+
for filepath in candidates:
name = get_filename(filepath)
- res[name] = filepath
- vae_list.clear()
- vae_list.extend(default_vae_list)
- vae_list.extend(list(res.keys()))
- vae_dict.clear()
- vae_dict.update(res)
- vae_dict.update(default_vae_dict)
- return vae_list
-
-
-def get_vae_from_settings(vae_file="auto"):
- # else, we load from settings, if not set to be default
- if vae_file == "auto" and shared.opts.sd_vae is not None:
- # if saved VAE settings isn't recognized, fallback to auto
- vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
- # if VAE selected but not found, fallback to auto
- if vae_file not in default_vae_values and not os.path.isfile(vae_file):
- vae_file = "auto"
- print(f"Selected VAE doesn't exist: {vae_file}")
- return vae_file
-
-
-def resolve_vae(checkpoint_file=None, vae_file="auto"):
- global first_load, vae_dict, vae_list
-
- # if vae_file argument is provided, it takes priority, but not saved
- if vae_file and vae_file not in default_vae_list:
- if not os.path.isfile(vae_file):
- print(f"VAE provided as function argument doesn't exist: {vae_file}")
- vae_file = "auto"
- # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
- if first_load and shared.cmd_opts.vae_path is not None:
- if os.path.isfile(shared.cmd_opts.vae_path):
- vae_file = shared.cmd_opts.vae_path
- shared.opts.data['sd_vae'] = get_filename(vae_file)
- else:
- print(f"VAE provided as command line argument doesn't exist: {vae_file}")
- # fallback to selector in settings, if vae selector not set to act as default fallback
- if not shared.opts.sd_vae_as_default:
- vae_file = get_vae_from_settings(vae_file)
- # vae-path cmd arg takes priority for auto
- if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
- if os.path.isfile(shared.cmd_opts.vae_path):
- vae_file = shared.cmd_opts.vae_path
- print(f"Using VAE provided as command line argument: {vae_file}")
- # if still not found, try look for ".vae.pt" beside model
- model_path = os.path.splitext(checkpoint_file)[0]
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.pt"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # if still not found, try look for ".vae.ckpt" beside model
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.ckpt"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # if still not found, try look for ".vae.safetensors" beside model
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.safetensors"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # No more fallbacks for auto
- if vae_file == "auto":
- vae_file = None
- # Last check, just because
- if vae_file and not os.path.exists(vae_file):
- vae_file = None
-
- return vae_file
-
-
-def load_vae(model, vae_file=None):
- global first_load, vae_dict, vae_list, loaded_vae_file
+ vae_dict[name] = filepath
+
+
+def find_vae_near_checkpoint(checkpoint_file):
+ checkpoint_path = os.path.splitext(checkpoint_file)[0]
+ for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+ if os.path.isfile(vae_location):
+ return vae_location
+
+ return None
+
+
+def resolve_vae(checkpoint_file):
+ if shared.cmd_opts.vae_path is not None:
+ return shared.cmd_opts.vae_path, 'from commandline argument'
+
+ is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+
+ vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+ if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
+ return vae_near_checkpoint, 'found near the checkpoint'
+
+ if shared.opts.sd_vae == "None":
+ return None, None
+
+ vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+ if vae_from_options is not None:
+ return vae_from_options, 'specified in settings'
+
+ if not is_automatic:
+ print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+ return None, None
+
+
+def load_vae_dict(filename, map_location):
+ vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ return vae_dict_1
+
+
+def load_vae(model, vae_file=None, vae_source="from unknown source"):
+ global vae_dict, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -165,16 +135,15 @@ def load_vae(model, vae_file=None):
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
- print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+ print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
else:
- assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
- print(f"Loading VAE weights from: {vae_file}")
+ assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
+ print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
- vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
@@ -191,14 +160,12 @@ def load_vae(model, vae_file=None):
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
- vae_list.append(vae_opt)
+
elif loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file
- first_load = False
-
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
@@ -211,7 +178,10 @@ def clear_loaded_vae():
loaded_vae_file = None
-def reload_vae_weights(sd_model=None, vae_file="auto"):
+unspecified = object()
+
+
+def reload_vae_weights(sd_model=None, vae_file=unspecified):
from modules import lowvram, devices, sd_hijack
if not sd_model:
@@ -219,7 +189,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
- vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+ if vae_file == unspecified:
+ vae_file, vae_source = resolve_vae(checkpoint_file)
+ else:
+ vae_source = "from function argument"
if loaded_vae_file == vae_file:
return
@@ -231,7 +205,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_vae(sd_model, vae_file)
+ load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
@@ -239,5 +213,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
- print("VAE Weights loaded.")
+ print("VAE weights loaded.")
return sd_model
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0a58542d..0027343a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -36,7 +36,7 @@ def model():
if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
diff --git a/modules/shared.py b/modules/shared.py
index 1c964237..29b28bff 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -20,12 +20,14 @@ from modules.paths import models_path, script_path, sd_path
demo = None
+sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -64,6 +66,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -83,7 +86,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -116,6 +119,7 @@ restricted_opts = {
}
ui_reorder_categories = [
+ "inpaint",
"sampler",
"dimensions",
"cfg",
@@ -152,6 +156,7 @@ def reload_hypernetworks():
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
class State:
skipped = False
interrupted = False
@@ -165,9 +170,11 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ id_live_preview = 0
textinfo = None
time_start = None
need_restart = False
+ server_start = None
def skip(self):
self.skipped = True
@@ -176,7 +183,7 @@ class State:
self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
@@ -206,6 +213,7 @@ class State:
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
+ self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
@@ -219,12 +227,12 @@ class State:
devices.torch_gc()
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not parallel_processing_allowed:
return
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
@@ -233,14 +241,19 @@ class State:
import modules.sd_samplers
if opts.show_progress_grid:
- self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
- self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
+ def assign_current_image(self, image):
+ self.current_image = image
+ self.id_live_preview += 1
+
state = State()
+state.server_start = time.time()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -358,9 +371,11 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
}))
options_templates.update(options_section(('system', "System"), {
+ "show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+ "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -373,14 +388,17 @@ options_templates.update(options_section(('training', "Training"), {
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
+ "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
+ "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
+ "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
- "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -418,14 +436,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
}))
options_templates.update(options_section(('ui', "User interface"), {
- "show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -440,6 +454,16 @@ options_templates.update(options_section(('ui', "User interface"), {
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
+options_templates.update(options_section(('ui', "Live previews"), {
+ "show_progressbar": OptionInfo(True, "Show progressbar"),
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
+ "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
+}))
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -454,6 +478,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
options_templates.update()
diff --git a/modules/styles.py b/modules/styles.py
index ce6e71ca..990d5623 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles):
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
- self.styles = {"None": self.no_style}
+ self.styles = {}
+ self.path = path
- if not os.path.exists(path):
+ self.reload()
+
+ def reload(self):
+ self.styles.clear()
+
+ if not os.path.exists(self.path):
return
- with open(path, "r", encoding="utf-8-sig", newline='') as file:
+ with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index fa48708e..d31963d4 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,8 +3,10 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices
import random
import tqdm
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
+ groups = defaultdict(list)
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
@@ -103,18 +105,25 @@ class PersonalizedBase(Dataset):
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
-
+ groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
self.length = len(self.dataset)
+ self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
+ if len(groups) > 1:
+ print("Buckets:")
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+ print(f" {w}x{h}: {len(ids)}")
+ print()
+
def create_text(self, filename_text):
text = random.choice(self.lines)
tags = filename_text.split(',')
@@ -137,9 +146,44 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
+
+class GroupedBatchSampler(Sampler):
+ def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ super().__init__(data_source)
+
+ n = len(data_source)
+ self.groups = data_source.groups
+ self.len = n_batch = n // batch_size
+ expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+ self.base = [int(e) // batch_size for e in expected]
+ self.n_rand_batches = nrb = n_batch - sum(self.base)
+ self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
+ self.batch_size = batch_size
+
+ def __len__(self):
+ return self.len
+
+ def __iter__(self):
+ b = self.batch_size
+
+ for g in self.groups:
+ shuffle(g)
+
+ batches = []
+ for g in self.groups:
+ batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+ for _ in range(self.n_rand_batches):
+ rand_group = choices(self.groups, self.probs)[0]
+ batches.append(choices(rand_group, k=b))
+
+ shuffle(batches)
+
+ yield from batches
+
+
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index ea653806..5593f88c 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))
- data_np_low.resize(next_size)
+ data_np_low = np.resize(data_np_low, next_size)
data_np_low = data_np_low.reshape((h, -1, d))
- data_np_high.resize(next_size)
+ data_np_high = np.resize(data_np_high, next_size)
data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 8b1981d5..734a4b6f 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,7 +2,7 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3c1042ad..c0ac11d3 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -20,7 +20,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
if process_caption_deepbooru:
deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
finally:
@@ -109,8 +109,30 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio):
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
+ wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
+ default=None
+ )
+ return wh and center_crop(image, *wh)
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -194,6 +216,14 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
+ if process_multicrop:
+ cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ if cropped is not None:
+ save_pic(cropped, index, params, existing_caption=existing_caption)
+ else:
+ print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
+ process_default_resize = False
+
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, params, existing_caption=existing_caption)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b915b091..5a7be422 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,9 +11,11 @@ import datetime
import csv
import safetensors.torch
+import numpy as np
from PIL import Image, PngImagePlugin
+from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -248,11 +250,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -291,6 +296,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def tensorboard_setup(log_directory):
+ os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+ return SummaryWriter(
+ log_dir=os.path.join(log_directory, "tensorboard"),
+ flush_secs=shared.opts.training_tensorboard_flush_every)
+
+def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
+ tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
+ tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
+ tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
+ tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
+
+def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
+ tensorboard_writer.add_scalar(tag=tag,
+ scalar_value=value, global_step=step)
+
+def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
+ # Convert a pil image to a torch tensor
+ img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+ len(pil_image.getbands()))
+ img_tensor = img_tensor.permute((2, 0, 1))
+
+ tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
@@ -316,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
@@ -369,13 +398,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
+
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = tensorboard_setup(log_directory)
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
@@ -420,6 +452,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -476,9 +510,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
@@ -528,10 +561,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
@@ -549,7 +586,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_mid = '[{}]'.format(checkpoint.shorthash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
@@ -582,16 +619,18 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
try:
- embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint = checkpoint.shorthash
embedding.sd_checkpoint_name = checkpoint.model_name
if remove_cached_checksum:
embedding.cached_checksum = None
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 38b5f591..e945fd69 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,13 +8,13 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
diff --git a/modules/ui.py b/modules/ui.py
index e86a624b..af416d5f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -11,6 +11,7 @@ import tempfile
import time
import traceback
from functools import partial, reduce
+import warnings
import gradio as gr
import gradio.routes
@@ -19,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
@@ -180,7 +183,7 @@ def add_style(name: str, prompt: str, negative_prompt: str):
# reserialize all styles every time we save them
shared.prompt_styles.save_styles(shared.styles_filename)
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
@@ -197,16 +200,38 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+def apply_styles(prompt, prompt_neg, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
+
+
+def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
+ if mode in {0, 1, 3, 4}:
+ return [interrogation_function(ii_singles[mode]), None]
+ elif mode == 2:
+ return [interrogation_function(ii_singles[mode]["image"]), None]
+ elif mode == 5:
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ images = shared.listfiles(ii_input_dir)
+ print(f"Will process {len(images)} images.")
+ if ii_output_dir != "":
+ os.makedirs(ii_output_dir, exist_ok=True)
+ else:
+ ii_output_dir = ii_input_dir
+
+ for image in images:
+ img = Image.open(image)
+ filename = os.path.basename(image)
+ left, _ = os.path.splitext(filename)
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+ return [gr_show(True), None]
def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
return gr_show(True) if prompt is None else prompt
@@ -356,9 +381,9 @@ def create_toprow(is_img2img):
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
- with gr.Row():
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ with gr.Row(elem_id=f"{id_part}_generate_box"):
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
@@ -374,19 +399,14 @@ def create_toprow(is_img2img):
)
with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
+ create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(*args, **kwargs):
- import modules.ui_progress
-
- modules.ui_progress.setup_progressbar(*args, **kwargs)
+ pass
def apply_setting(key, value):
@@ -422,17 +442,16 @@ def apply_setting(key, value):
return value
-def update_generation_info(args):
- generation_info, html_info, img_index = args
+def update_generation_info(generation_info, html_info, img_index):
try:
generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
except Exception:
pass
# if the json parse or anything else fails, just return the old html_info
- return html_info
+ return html_info, gr.update()
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
@@ -479,8 +498,8 @@ Requested path was: {f}
else:
sp.Popen(["xdg-open", path])
- with gr.Column(variant='panel'):
- with gr.Group():
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
generation_info = None
@@ -513,10 +532,9 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
+ _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
)
save.click(
@@ -531,7 +549,8 @@ Requested path was: {f}
outputs=[
download_files,
html_log,
- ]
+ ],
+ show_progress=False,
)
save_zip.click(
@@ -572,9 +591,9 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
@@ -590,22 +609,13 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -628,7 +638,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
+ with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@@ -636,12 +646,12 @@ def create_ui():
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
- with FormRow(elem_id="txt2img_hires_fix_row2"):
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
@@ -682,10 +692,10 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
+ dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
+ txt2img_prompt_styles,
steps,
sampler_index,
restore_faces,
@@ -780,34 +790,45 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
- with gr.Column(scale=1):
- pass
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='compact', elem_id="img2img_settings"):
+ copy_image_buttons = []
+ copy_image_destinations = {}
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
+ def add_copy_image_controls(tab_name, elem):
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+ if name == tab_name:
+ gr.Button(title, interactive=False)
+ copy_image_destinations[name] = elem
+ continue
+
+ button = gr.Button(title)
+ copy_image_buttons.append((button, name, elem))
- with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img"):
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
def update_orig(image, state):
if image is not None:
@@ -824,36 +845,27 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
+ gr.HTML(f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
-
- with FormRow():
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+ return img
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- def select_img2img_tab(tab):
- return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
-
- for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
- elem.select(
- fn=lambda tab=i: select_img2img_tab(tab),
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js="switch_to_"+name.replace(" ", "_"),
inputs=[],
- outputs=[inpaint_controls, mask_alpha],
+ outputs=[],
)
with FormRow():
@@ -897,6 +909,35 @@ def create_ui():
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+ elif category == "inpaint":
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
@@ -919,10 +960,10 @@ def create_ui():
_js="submit_img2img",
inputs=[
dummy_component,
+ dummy_component,
img2img_prompt,
img2img_negative_prompt,
- img2img_prompt_style,
- img2img_prompt_style2,
+ img2img_prompt_styles,
init_img,
sketch,
init_img_with_mask,
@@ -961,23 +1002,37 @@ def create_ui():
show_progress=False,
)
+ interrogate_args = dict(
+ _js="get_img2img_tab_index",
+ inputs=[
+ dummy_component,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ init_img_inpaint,
+ ],
+ outputs=[img2img_prompt, dummy_component],
+ show_progress=False,
+ )
+
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
img2img_interrogate.click(
- fn=interrogate,
- inputs=[init_img],
- outputs=[img2img_prompt],
+ fn=lambda *args : process_interrogate(interrogate, *args),
+ **interrogate_args,
)
img2img_deepbooru.click(
- fn=interrogate_deepbooru,
- inputs=[init_img],
- outputs=[img2img_prompt],
+ fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
+ **interrogate_args,
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+ style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
@@ -987,15 +1042,15 @@ def create_ui():
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
+ outputs=[txt2img_prompt_styles, img2img_prompt_styles],
)
- for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
+ for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click(
fn=apply_styles,
_js=js_func,
- inputs=[prompt, negative_prompt, style1, style2],
- outputs=[prompt, negative_prompt, style1, style2],
+ inputs=[prompt, negative_prompt, styles],
+ outputs=[prompt, negative_prompt, styles],
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1026,7 +1081,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
+ with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', elem_id="extras_single_tab"):
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
@@ -1127,10 +1182,10 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
+ with gr.Column(variant='compact'):
+ gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
- with FormRow():
+ with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1142,18 +1197,27 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+ with FormRow():
+ with gr.Column():
+ config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+ with gr.Column():
+ with FormRow():
+ bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+ create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
- with gr.Column(variant='panel'):
- submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ with gr.Row():
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+ with gr.Group(elem_id="modelmerger_results_panel"):
+ modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
@@ -1204,6 +1268,7 @@ def create_ui():
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
@@ -1216,7 +1281,19 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
+
+ with gr.Column(visible=False) as process_multicrop_col:
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+ with gr.Row():
+ process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
+ process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
+ with gr.Row():
+ process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
+ process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
+ with gr.Row():
+ process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
+ process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1238,6 +1315,12 @@ def create_ui():
outputs=[process_focal_crop_row],
)
+ process_multicrop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_multicrop],
+ outputs=[process_multicrop_col],
+ )
+
def get_textual_inversion_template_names():
return sorted([x for x in textual_inversion.textual_inversion_templates])
@@ -1295,15 +1378,11 @@ def create_ui():
script_callbacks.ui_train_tabs_callback(params)
- with gr.Column():
- progressbar = gr.HTML(elem_id="ti_progressbar")
+ with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
- setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
@@ -1344,6 +1423,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
process_src,
process_dst,
process_width,
@@ -1360,6 +1440,13 @@ def create_ui():
process_focal_crop_entropy_weight,
process_focal_crop_edges_weight,
process_focal_crop_debug,
+ process_multicrop,
+ process_multicrop_mindim,
+ process_multicrop_maxdim,
+ process_multicrop_minarea,
+ process_multicrop_maxarea,
+ process_multicrop_objective,
+ process_multicrop_threshold,
],
outputs=[
ti_output,
@@ -1371,6 +1458,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_embedding_name,
embedding_learn_rate,
batch_size,
@@ -1403,6 +1491,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
@@ -1529,6 +1618,7 @@ def create_ui():
previous_section = None
current_tab = None
+ current_row = None
with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
@@ -1537,10 +1627,14 @@ def create_ui():
elem_id, text = item.section
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
+ gr.Group()
current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
previous_section = item.section
@@ -1555,6 +1649,7 @@ def create_ui():
components.append(component)
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
with gr.TabItem("Actions"):
@@ -1636,7 +1731,7 @@ def create_ui():
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings"):
+ with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
@@ -1692,12 +1787,15 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
+ modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click(
- fn=modelmerger,
+ fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
inputs=[
+ dummy_component,
primary_model_name,
secondary_model_name,
tertiary_model_name,
@@ -1707,13 +1805,14 @@ def create_ui():
custom_name,
checkpoint_format,
config_source,
+ bake_in_vae,
],
outputs=[
- submit_result,
primary_model_name,
secondary_model_name,
tertiary_model_name,
component_dict['sd_model_checkpoint'],
+ modelmerger_result,
]
)
@@ -1745,7 +1844,10 @@ def create_ui():
if saved_value is None:
ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
if init_field is not None:
@@ -1773,7 +1875,13 @@ def create_ui():
apply_field(x, 'value')
if type(x) == gr.Dropdown:
- apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
+ def check_dropdown(val):
+ if x.multiselect:
+ return all([value in x.choices for value in val])
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
@@ -1841,4 +1949,6 @@ xformers: {xformers_version}
gradio: {gr.__version__}
 • 
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
+ • 
+checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
deleted file mode 100644
index 592fda55..00000000
--- a/modules/ui_progress.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import time
-
-import gradio as gr
-
-from modules.shared import opts
-
-import modules.shared as shared
-
-
-def calc_time_left(progress, threshold, label, force_display, show_eta):
- if progress == 0:
- return ""
- else:
- time_since_start = time.time() - shared.state.time_start
- eta = (time_since_start/progress)
- eta_relative = eta-time_since_start
- if (eta_relative > threshold and show_eta) or force_display:
- if eta_relative > 3600:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
- elif eta_relative > 60:
- return label + time.strftime('%M:%S', time.gmtime(eta_relative))
- else:
- return label + time.strftime('%Ss', time.gmtime(eta_relative))
- else:
- return ""
-
-
-def check_progress_call(id_part):
- if shared.state.job_count == 0:
- return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
- progress = 0
-
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
- # Show progress percentage and time left at the same moment, and base it also on steps done
- show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
-
- time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
- if time_left != "":
- shared.state.time_left_force_display = True
-
- progress = min(progress, 1)
-
- progressbar = ""
- if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
-
- image = gr.update(visible=False)
- preview_visibility = gr.update(visible=False)
-
- if opts.show_progress_every_n_steps != 0:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr.update(visible=True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr.update(visible=False)
-
- return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 231680cb..a5bf5acb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -95,6 +95,7 @@ class UpscalerData:
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
+ self.local_data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model