Diffstat (limited to 'modules')
-rw-r--r--  modules/call_queue.py                              19
-rw-r--r--  modules/devices.py                                 31
-rw-r--r--  modules/errors.py                                   2
-rw-r--r--  modules/extras.py                                  10
-rw-r--r--  modules/generation_parameters_copypaste.py          3
-rw-r--r--  modules/hashes.py                                  29
-rw-r--r--  modules/hypernetworks/hypernetwork.py              17
-rw-r--r--  modules/images.py                                   5
-rw-r--r--  modules/img2img.py                                  2
-rw-r--r--  modules/processing.py                              12
-rw-r--r--  modules/progress.py                                96
-rw-r--r--  modules/prompt_parser.py                            7
-rw-r--r--  modules/realesrgan_model.py                        12
-rw-r--r--  modules/sd_hijack_clip.py                          17
-rw-r--r--  modules/sd_models.py                               15
-rw-r--r--  modules/sd_samplers.py                             19
-rw-r--r--  modules/sd_vae.py                                 196
-rw-r--r--  modules/sd_vae_approx.py                            2
-rw-r--r--  modules/shared.py                                  42
-rw-r--r--  modules/textual_inversion/logging.py                2
-rw-r--r--  modules/textual_inversion/preprocess.py             2
-rw-r--r--  modules/textual_inversion/textual_inversion.py      6
-rw-r--r--  modules/txt2img.py                                  2
-rw-r--r--  modules/ui.py                                     137
-rw-r--r--  modules/ui_progress.py                            101
-rw-r--r--  modules/upscaler.py                                 1
26 files changed, 440 insertions(+), 347 deletions(-)
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 4cd49533..92097c15 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -4,7 +4,7 @@ import threading
import traceback
import time
-from modules import shared
+from modules import shared, progress
queue_lock = threading.Lock()
@@ -22,12 +22,23 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
- shared.state.begin()
+ # if the first argument is a string that says "task(...)", it is treated as a job id
+ if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ id_task = args[0]
+ progress.add_task_to_queue(id_task)
+ else:
+ id_task = None
with queue_lock:
- res = func(*args, **kwargs)
+ shared.state.begin()
+ progress.start_task(id_task)
+
+ try:
+ res = func(*args, **kwargs)
+ finally:
+ progress.finish_task(id_task)
- shared.state.end()
+ shared.state.end()
return res
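As an illustrative aside, the job-id convention introduced above amounts to the following check: a job id is a string of the form "task(...)" passed as the first positional argument, and anything else leaves id_task as None. This sketch is not part of the patch; the sample argument tuples are hypothetical.

def extract_task_id(args):
    # mirrors the check in wrap_gradio_gpu_call: the first arg must look like "task(...)"
    if len(args) > 0 and isinstance(args[0], str) and args[0].startswith("task(") and args[0].endswith(")"):
        return args[0]
    return None

print(extract_task_id(("task(abc123)", "a prompt", 20)))  # -> task(abc123)
print(extract_task_id(("a prompt", 20)))                  # -> None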
diff --git a/modules/devices.py b/modules/devices.py
index caeb0276..206184fb 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -106,6 +106,36 @@ def autocast(disable=False):
return torch.autocast("cuda")
+class NansException(Exception):
+ pass
+
+
+def test_for_nans(x, where):
+ from modules import shared
+
+ if shared.cmd_opts.disable_nan_check:
+ return
+
+ if not torch.all(torch.isnan(x)).item():
+ return
+
+ if where == "unet":
+ message = "A tensor with all NaNs was produced in Unet."
+
+ if not shared.cmd_opts.no_half:
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+
+ elif where == "vae":
+ message = "A tensor with all NaNs was produced in VAE."
+
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
+ else:
+ message = "A tensor with all NaNs was produced."
+
+ raise NansException(message)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
@@ -156,3 +186,4 @@ if has_mps():
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
+
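A hedged usage sketch for the new NaN check, not part of the patch: test_for_nans only raises when every element of the tensor is NaN, and the message hints at --no-half / --no-half-vae depending on where the tensor came from. It assumes torch and the webui's modules package are importable.

import torch
from modules import devices

latent = torch.full((1, 4, 64, 64), float("nan"))  # an all-NaN tensor, the only case that triggers
try:
    devices.test_for_nans(latent, "unet")
except devices.NansException as err:
    print(err)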
diff --git a/modules/errors.py b/modules/errors.py
index a668c014..a10e8708 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -19,7 +19,7 @@ def display(e: Exception, task):
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
-The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
diff --git a/modules/extras.py b/modules/extras.py
index a03d558e..22668fcd 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -326,8 +326,14 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
print("Merging...")
+ chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
+
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
a = theta_0[key]
b = theta_1[key]
@@ -352,6 +358,10 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
+
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 593d99ef..a381ff59 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -37,6 +37,9 @@ def quote(text):
def image_from_url_text(filedata):
+ if filedata is None:
+ return None
+
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
diff --git a/modules/hashes.py b/modules/hashes.py
index ebfbd90c..b85a7580 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -34,31 +34,44 @@ def cache(subsection):
def calculate_sha256(filename):
hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
with open(filename, "rb") as f:
- for chunk in iter(lambda: f.read(4096), b""):
+ for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
-def sha256(filename, title):
+def sha256_from_cache(filename, title):
hashes = cache("hashes")
ondisk_mtime = os.path.getmtime(filename)
- if title in hashes:
- cached_sha256 = hashes[title].get("sha256", None)
- cached_mtime = hashes[title].get("mtime", 0)
+ if title not in hashes:
+ return None
+
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+
+ return cached_sha256
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
- if ondisk_mtime <= cached_mtime and cached_sha256 is not None:
- return cached_sha256
+ sha256_value = sha256_from_cache(filename, title)
+ if sha256_value is not None:
+ return sha256_value
print(f"Calculating sha256 for {filename}: ", end='')
sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {
- "mtime": ondisk_mtime,
+ "mtime": os.path.getmtime(filename),
"sha256": sha256_value,
}
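A sketch of the split cache/compute flow above, for orientation only; the checkpoint path and title are hypothetical. sha256_from_cache consults the cache alone (returning None on a miss or a stale mtime), while sha256 falls back to hashing the file in 1 MiB blocks and storing the result.

from modules import hashes

filename = "models/Stable-diffusion/example.ckpt"   # hypothetical file
title = "checkpoint/example.ckpt"

value = hashes.sha256_from_cache(filename, title)    # cache-only lookup
if value is None:
    value = hashes.sha256(filename, title)           # computes, prints, and caches mtime + sha256
print(value)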
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3aebefa8..c963fc40 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -561,6 +561,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
@@ -610,7 +611,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+ loss_logging.append(_loss_step)
if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate)
@@ -629,7 +630,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -645,7 +645,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.opts.training_enable_tensorboard:
epoch_num = hypernetwork.step // len(ds)
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
-
+ mean_loss = sum(loss_logging) / len(loss_logging)
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
@@ -689,9 +689,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
-
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -701,7 +698,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
torch.cuda.set_rng_state_all(cuda_rng_state)
hypernetwork.train()
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer,
+ f"Validation at epoch {epoch_num}", image,
+ hypernetwork.step)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
diff --git a/modules/images.py b/modules/images.py
index c3a5fc8b..3b1c5f34 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -605,8 +605,9 @@ def read_info_from_image(image):
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
- items['exif comment'] = exif_comment
- geninfo = exif_comment
+ if exif_comment:
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
diff --git a/modules/img2img.py b/modules/img2img.py
index 79382cc1..2168c8e2 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
diff --git a/modules/processing.py b/modules/processing.py
index 849f6b19..9c3673de 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return image_conditioning
-class StableDiffusionProcessing():
+class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
@@ -102,7 +102,6 @@ class StableDiffusionProcessing():
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
- self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
@@ -156,6 +155,10 @@ class StableDiffusionProcessing():
self.all_subseeds = None
self.iteration = 0
+ @property
+ def sd_model(self):
+ return shared.sd_model
+
def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@@ -236,7 +239,6 @@ class StableDiffusionProcessing():
raise NotImplementedError()
def close(self):
- self.sd_model = None
self.sampler = None
@@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_model_checkpoint':
sd_models.reload_model_weights() # make onchange call for changing SD model
- p.sd_model = shared.sd_model
if k == 'sd_vae':
sd_vae.reload_vae_weights() # make onchange call for changing VAE
@@ -608,6 +609,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+ for x in x_samples_ddim:
+ devices.test_for_nans(x, "vae")
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/progress.py b/modules/progress.py
new file mode 100644
index 00000000..3327b883
--- /dev/null
+++ b/modules/progress.py
@@ -0,0 +1,96 @@
+import base64
+import io
+import time
+
+import gradio as gr
+from pydantic import BaseModel, Field
+
+from modules.shared import opts
+
+import modules.shared as shared
+
+
+current_task = None
+pending_tasks = {}
+finished_tasks = []
+
+
+def start_task(id_task):
+ global current_task
+
+ current_task = id_task
+ pending_tasks.pop(id_task, None)
+
+
+def finish_task(id_task):
+ global current_task
+
+ if current_task == id_task:
+ current_task = None
+
+ finished_tasks.append(id_task)
+ if len(finished_tasks) > 16:
+ finished_tasks.pop(0)
+
+
+def add_task_to_queue(id_job):
+ pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+ id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received preview image")
+
+
+class ProgressResponse(BaseModel):
+ active: bool = Field(title="Whether the task is being worked on right now")
+ queued: bool = Field(title="Whether the task is in queue")
+ completed: bool = Field(title="Whether the task has already finished")
+ progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+ eta: float = Field(default=None, title="ETA in secs")
+ live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+ id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+ textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def setup_progress_api(app):
+ return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+
+
+def progressapi(req: ProgressRequest):
+ active = req.id_task == current_task
+ queued = req.id_task in pending_tasks
+ completed = req.id_task in finished_tasks
+
+ if not active:
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+
+ progress = 0
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ progress = min(progress, 1)
+
+ elapsed_since_start = time.time() - shared.state.time_start
+ predicted_duration = elapsed_since_start / progress if progress > 0 else None
+ eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+
+ id_live_preview = req.id_live_preview
+ shared.state.set_current_image()
+ if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+ image = shared.state.current_image
+ if image is not None:
+ buffered = io.BytesIO()
+ image.save(buffered, format="png")
+ live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+ id_live_preview = shared.state.id_live_preview
+ else:
+ live_preview = None
+ else:
+ live_preview = None
+
+ return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
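A client-side polling sketch for the new /internal/progress endpoint, not part of the patch. The URL, port, sleep interval, and task id are assumptions; the request and response fields come from ProgressRequest/ProgressResponse above, and the requests library is assumed to be available on the client.

import time
import requests

def poll_progress(id_task, url="http://127.0.0.1:7860/internal/progress"):
    id_live_preview = -1
    while True:
        resp = requests.post(url, json={"id_task": id_task, "id_live_preview": id_live_preview}).json()
        print(resp.get("textinfo"), resp.get("progress"), resp.get("eta"))
        if resp.get("completed") or not (resp.get("active") or resp.get("queued")):
            break
        id_live_preview = resp.get("id_live_preview", id_live_preview)  # avoids re-downloading the same preview
        time.sleep(1)

poll_progress("task(abc123)")  # hypothetical task id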
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 870218db..69665372 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -274,6 +274,7 @@ re_attention = re.compile(r"""
:
""", re.X)
+re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text):
"""
@@ -339,7 +340,11 @@ def parse_prompt_attention(text):
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
- res.append([text, 1.0])
+ parts = re.split(re_break, text)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ res.append([part, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
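An illustrative expectation for the BREAK handling above: the keyword is split out into its own ["BREAK", -1] entry so the CLIP hijack can start a fresh chunk. The prompt text is made up and the printed shape is approximate, not a verified output.

from modules.prompt_parser import parse_prompt_attention

print(parse_prompt_attention("a castle BREAK a forest"))
# roughly: [['a castle', 1.0], ['BREAK', -1], ['a forest', 1.0]]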
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 3ac0b97a..47f70251 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler):
return img
info = self.load_model(path)
- if not os.path.exists(info.data_path):
+ if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
- model_path=info.data_path,
+ model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
@@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler):
def load_model(self, path):
try:
- info = None
- for scaler in self.scalers:
- if scaler.data_path == path:
- info = scaler
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
if info is None:
print(f"Unable to find model info: {path}")
return None
- model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- info.data_path = model_file
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 852afc66..9fa5c5c5 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
token_count = 0
last_comma = -1
- def next_chunk():
- """puts current chunk into the list of results and produces the next one - empty"""
+ def next_chunk(is_last=False):
+ """puts current chunk into the list of results and produces the next one - empty;
+            if is_last is true, <end-of-text> tokens at the end won't add to token_count"""
nonlocal token_count
nonlocal last_comma
nonlocal chunk
- token_count += len(chunk.tokens)
+ if is_last:
+ token_count += len(chunk.tokens)
+ else:
+ token_count += self.chunk_length
+
to_add = self.chunk_length - len(chunk.tokens)
if to_add > 0:
chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk = PromptChunk()
for tokens, (text, weight) in zip(tokenized, parsed):
+ if text == 'BREAK' and weight == -1:
+ next_chunk()
+ continue
+
position = 0
while position < len(tokens):
token = tokens[position]
@@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0:
- next_chunk()
+ next_chunk(is_last=True)
return chunks, token_count
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1fe6d11b..6a681cef 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -44,9 +44,11 @@ class CheckpointInfo:
self.title = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]']
- self.shorthash = None
- self.sha256 = None
+
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -222,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+def load_model_weights(model, checkpoint_info: CheckpointInfo):
sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -269,13 +271,14 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
- sd_vae.load_vae(model, vae_file)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ sd_vae.load_vae(model, vae_file, vae_source)
def enable_midas_autodownload():
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 01221b89..6261d1f7 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -138,9 +138,9 @@ def samples_to_image_grid(samples, approximation=None):
def store_latent(decoded):
state.current_latent = decoded
- if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
- shared.state.current_image = sample_to_image(decoded)
+ shared.state.assign_current_image(sample_to_image(decoded))
class InterruptedException(BaseException):
@@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler:
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
@@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler:
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-
+
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -352,6 +351,13 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ devices.test_for_nans(x_out, "unet")
+
+ if opts.live_preview_content == "Prompt":
+ store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ store_latent(x_out[-uncond.shape[0]:])
+
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
@@ -423,7 +429,8 @@ class KDiffusionSampler:
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
- store_latent(latent)
+ if opts.live_preview_content == "Combined":
+ store_latent(latent)
self.last_latent = latent
if self.stop_at is not None and step > self.stop_at:
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0a49daa1..b2af2ce7 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -9,23 +9,9 @@ import glob
from copy import deepcopy
-model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(models_path, model_dir))
-vae_dir = "VAE"
-vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
-
-
+vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-
-default_vae_dict = {"auto": "auto", "None": None, None: None}
-default_vae_list = ["auto", "None"]
-
-
-default_vae_values = [default_vae_dict[x] for x in default_vae_list]
-vae_dict = dict(default_vae_dict)
-vae_list = list(default_vae_list)
-first_load = True
+vae_dict = {}
base_vae = None
@@ -64,100 +50,71 @@ def restore_base_vae(model):
def get_filename(filepath):
- return os.path.splitext(os.path.basename(filepath))[0]
-
-
-def refresh_vae_list(vae_path=vae_path, model_path=model_path):
- global vae_dict, vae_list
- res = {}
- candidates = [
- *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
- *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
- *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
+ return os.path.basename(filepath)
+
+
+def refresh_vae_list():
+ vae_dict.clear()
+
+ paths = [
+ os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
+ os.path.join(sd_models.model_path, '**/*.vae.pt'),
+ os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
+ os.path.join(vae_path, '**/*.ckpt'),
+ os.path.join(vae_path, '**/*.pt'),
+ os.path.join(vae_path, '**/*.safetensors'),
]
- if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
- candidates.append(shared.cmd_opts.vae_path)
+
+ if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
+ os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
+ ]
+
+ candidates = []
+ for path in paths:
+ candidates += glob.iglob(path, recursive=True)
+
for filepath in candidates:
name = get_filename(filepath)
- res[name] = filepath
- vae_list.clear()
- vae_list.extend(default_vae_list)
- vae_list.extend(list(res.keys()))
- vae_dict.clear()
- vae_dict.update(res)
- vae_dict.update(default_vae_dict)
- return vae_list
-
-
-def get_vae_from_settings(vae_file="auto"):
- # else, we load from settings, if not set to be default
- if vae_file == "auto" and shared.opts.sd_vae is not None:
- # if saved VAE settings isn't recognized, fallback to auto
- vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
- # if VAE selected but not found, fallback to auto
- if vae_file not in default_vae_values and not os.path.isfile(vae_file):
- vae_file = "auto"
- print(f"Selected VAE doesn't exist: {vae_file}")
- return vae_file
-
-
-def resolve_vae(checkpoint_file=None, vae_file="auto"):
- global first_load, vae_dict, vae_list
-
- # if vae_file argument is provided, it takes priority, but not saved
- if vae_file and vae_file not in default_vae_list:
- if not os.path.isfile(vae_file):
- print(f"VAE provided as function argument doesn't exist: {vae_file}")
- vae_file = "auto"
- # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
- if first_load and shared.cmd_opts.vae_path is not None:
- if os.path.isfile(shared.cmd_opts.vae_path):
- vae_file = shared.cmd_opts.vae_path
- shared.opts.data['sd_vae'] = get_filename(vae_file)
- else:
- print(f"VAE provided as command line argument doesn't exist: {vae_file}")
- # fallback to selector in settings, if vae selector not set to act as default fallback
- if not shared.opts.sd_vae_as_default:
- vae_file = get_vae_from_settings(vae_file)
- # vae-path cmd arg takes priority for auto
- if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
- if os.path.isfile(shared.cmd_opts.vae_path):
- vae_file = shared.cmd_opts.vae_path
- print(f"Using VAE provided as command line argument: {vae_file}")
- # if still not found, try look for ".vae.pt" beside model
- model_path = os.path.splitext(checkpoint_file)[0]
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.pt"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # if still not found, try look for ".vae.ckpt" beside model
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.ckpt"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # if still not found, try look for ".vae.safetensors" beside model
- if vae_file == "auto":
- vae_file_try = model_path + ".vae.safetensors"
- if os.path.isfile(vae_file_try):
- vae_file = vae_file_try
- print(f"Using VAE found similar to selected model: {vae_file}")
- # No more fallbacks for auto
- if vae_file == "auto":
- vae_file = None
- # Last check, just because
- if vae_file and not os.path.exists(vae_file):
- vae_file = None
-
- return vae_file
-
-
-def load_vae(model, vae_file=None):
- global first_load, vae_dict, vae_list, loaded_vae_file
+ vae_dict[name] = filepath
+
+
+def find_vae_near_checkpoint(checkpoint_file):
+ checkpoint_path = os.path.splitext(checkpoint_file)[0]
+ for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+ if os.path.isfile(vae_location):
+ return vae_location
+
+ return None
+
+
+def resolve_vae(checkpoint_file):
+ if shared.cmd_opts.vae_path is not None:
+ return shared.cmd_opts.vae_path, 'from commandline argument'
+
+ is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+
+ vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+ if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
+ return vae_near_checkpoint, 'found near the checkpoint'
+
+ if shared.opts.sd_vae == "None":
+ return None, None
+
+ vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+ if vae_from_options is not None:
+ return vae_from_options, 'specified in settings'
+
+ if not is_automatic:
+ print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+ return None, None
+
+
+def load_vae(model, vae_file=None, vae_source="from unknown source"):
+ global vae_dict, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -165,12 +122,12 @@ def load_vae(model, vae_file=None):
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
- print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+ print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
else:
- assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
- print(f"Loading VAE weights from: {vae_file}")
+ assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
+ print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
@@ -191,14 +148,12 @@ def load_vae(model, vae_file=None):
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
- vae_list.append(vae_opt)
+
elif loaded_vae_file:
restore_base_vae(model)
loaded_vae_file = vae_file
- first_load = False
-
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
@@ -211,7 +166,10 @@ def clear_loaded_vae():
loaded_vae_file = None
-def reload_vae_weights(sd_model=None, vae_file="auto"):
+unspecified = object()
+
+
+def reload_vae_weights(sd_model=None, vae_file=unspecified):
from modules import lowvram, devices, sd_hijack
if not sd_model:
@@ -219,7 +177,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
- vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+ if vae_file == unspecified:
+ vae_file, vae_source = resolve_vae(checkpoint_file)
+ else:
+ vae_source = "from function argument"
if loaded_vae_file == vae_file:
return
@@ -231,7 +193,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_vae(sd_model, vae_file)
+ load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
@@ -239,5 +201,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
- print("VAE Weights loaded.")
+ print("VAE weights loaded.")
return sd_model
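For orientation, the resolution order encoded in resolve_vae above can be exercised as follows; this is an illustrative sketch rather than part of the change, and the checkpoint path is hypothetical.

from modules import sd_vae

# Priority per resolve_vae: the --vae-path argument wins outright; otherwise a
# .vae.pt/.vae.ckpt/.vae.safetensors file next to the checkpoint is used when
# sd_vae_as_default is on or the setting is "Automatic"; otherwise the VAE named
# in settings, with "None" disabling an external VAE.
vae_file, vae_source = sd_vae.resolve_vae("models/Stable-diffusion/example.ckpt")
print(vae_file, vae_source)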
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0a58542d..0027343a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -36,7 +36,7 @@ def model():
if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
diff --git a/modules/shared.py b/modules/shared.py
index a6c61db3..a708f23c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -64,6 +64,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -83,7 +84,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -116,6 +117,7 @@ restricted_opts = {
}
ui_reorder_categories = [
+ "inpaint",
"sampler",
"dimensions",
"cfg",
@@ -152,6 +154,7 @@ def reload_hypernetworks():
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
class State:
skipped = False
interrupted = False
@@ -165,9 +168,11 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ id_live_preview = 0
textinfo = None
time_start = None
need_restart = False
+ server_start = None
def skip(self):
self.skipped = True
@@ -176,7 +181,7 @@ class State:
self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
@@ -206,6 +211,7 @@ class State:
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
+ self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
@@ -219,12 +225,12 @@ class State:
devices.torch_gc()
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not parallel_processing_allowed:
return
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
@@ -233,14 +239,19 @@ class State:
import modules.sd_samplers
if opts.show_progress_grid:
- self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
- self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
+ def assign_current_image(self, image):
+ self.current_image = image
+ self.id_live_preview += 1
+
state = State()
+state.server_start = time.time()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -383,8 +394,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
- "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -422,10 +433,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
}))
options_templates.update(options_section(('ui', "User interface"), {
- "show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -444,6 +451,16 @@ options_templates.update(options_section(('ui', "User interface"), {
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
+options_templates.update(options_section(('ui', "Live previews"), {
+ "show_progressbar": OptionInfo(True, "Show progressbar"),
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
+ "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
+}))
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -458,6 +475,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
options_templates.update()
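A small sketch of how assign_current_image feeds the live-preview id used by the progress API: each assignment stores the image and bumps id_live_preview, which is what lets clients skip previews they already have. Not part of the patch; it assumes the webui environment is importable, and the PIL image is a stand-in.

from PIL import Image
from modules import shared

img = Image.new("RGB", (64, 64))
before = shared.state.id_live_preview
shared.state.assign_current_image(img)
assert shared.state.id_live_preview == before + 1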
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 31e50b64..734a4b6f 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,7 +2,7 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3c1042ad..64abff4d 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 63935878..7e4a6d24 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
@@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
- shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
@@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
- shared.state.current_image = image
+ shared.state.assign_current_image(image)
+
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 5a71793b..e945fd69 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
diff --git a/modules/ui.py b/modules/ui.py
index db198a47..e1f98d23 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -356,9 +356,9 @@ def create_toprow(is_img2img):
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1):
- with gr.Row():
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+ with gr.Row(elem_id=f"{id_part}_generate_box"):
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
@@ -381,9 +381,7 @@ def create_toprow(is_img2img):
def setup_progressbar(*args, **kwargs):
- import modules.ui_progress
-
- modules.ui_progress.setup_progressbar(*args, **kwargs)
+ pass
def apply_setting(key, value):
@@ -476,8 +474,8 @@ Requested path was: {f}
else:
sp.Popen(["xdg-open", path])
- with gr.Column(variant='panel'):
- with gr.Group():
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
generation_info = None
@@ -569,9 +567,9 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
- user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
@@ -592,15 +590,6 @@ def create_ui():
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
- with gr.Row(elem_id='txt2img_progress_row'):
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="txt2img_progressbar")
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
@@ -679,6 +668,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
+ dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
txt2img_prompt_styles,
@@ -778,32 +768,43 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- with gr.Row(elem_id='img2img_progress_row'):
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
- with gr.Column(scale=1):
- pass
-
- with gr.Column(scale=1):
- progressbar = gr.HTML(elem_id="img2img_progressbar")
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- setup_progressbar(progressbar, img2img_preview, 'img2img')
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
+ copy_image_buttons = []
+ copy_image_destinations = {}
+
+ def add_copy_image_controls(tab_name, elem):
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+ if name == tab_name:
+ gr.Button(title, interactive=False)
+ copy_image_destinations[name] = elem
+ continue
+
+ button = gr.Button(title)
+ copy_image_buttons.append((button, name, elem))
+
with gr.Tabs(elem_id="mode_img2img"):
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
def update_orig(image, state):
if image is not None:
@@ -820,36 +821,27 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
+ gr.HTML(f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
- with FormRow():
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+ return img
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- def select_img2img_tab(tab):
- return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
-
- for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
- elem.select(
- fn=lambda tab=i: select_img2img_tab(tab),
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js="switch_to_"+name.replace(" ", "_"),
inputs=[],
- outputs=[inpaint_controls, mask_alpha],
+ outputs=[],
)
with FormRow():
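The copy buttons wired up above route whatever the source component currently holds through copy_image. The plain img2img component yields the image itself, while the sketch- and inpaint-style editors can hand back a dict carrying the drawn image under an 'image' key (typically alongside a mask), which is why the helper unwraps that case. A small behavioural sketch, assuming Pillow is available:

from PIL import Image

def copy_image(img):  # same logic as the helper added above
    if isinstance(img, dict) and 'image' in img:
        return img['image']
    return img

plain = Image.new("RGBA", (64, 64))
sketched = {"image": plain, "mask": Image.new("RGBA", (64, 64))}

assert copy_image(plain) is plain     # plain component: passed through unchanged
assert copy_image(sketched) is plain  # sketch/inpaint component: 'image' is unwrapped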
@@ -893,6 +885,35 @@ def create_ui():
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+ elif category == "inpaint":
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
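For reference, the tabs are enumerated in the order img2img, sketch, inpaint, inpaint sketch, inpaint upload, batch, so select_img2img_tab shows the inpaint controls for indices 2-4 and the mask transparency slider only for index 3. A tiny printout of that mapping, using the same index tests:

tabs = ['img2img', 'sketch', 'inpaint', 'inpaint sketch', 'inpaint upload', 'batch']

for i, name in enumerate(tabs):
    show_inpaint_controls = i in [2, 3, 4]   # same test as select_img2img_tab
    show_mask_alpha = i == 3                 # only the "Inpaint sketch" tab
    print(f"{name:15} inpaint_controls={show_inpaint_controls}  mask_alpha={show_mask_alpha}")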
@@ -915,6 +936,7 @@ def create_ui():
_js="submit_img2img",
inputs=[
dummy_component,
+ dummy_component,
img2img_prompt,
img2img_negative_prompt,
img2img_prompt_styles,
@@ -1291,15 +1313,11 @@ def create_ui():
script_callbacks.ui_train_tabs_callback(params)
- with gr.Column():
- progressbar = gr.HTML(elem_id="ti_progressbar")
+ with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
- setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
@@ -1340,6 +1358,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
process_src,
process_dst,
process_width,
@@ -1367,6 +1386,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_embedding_name,
embedding_learn_rate,
batch_size,
@@ -1399,6 +1419,7 @@ def create_ui():
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
+ dummy_component,
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
@@ -1849,4 +1870,6 @@ xformers: {xformers_version}
gradio: {gr.__version__}
 • 
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
+ • 
+checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
deleted file mode 100644
index 592fda55..00000000
--- a/modules/ui_progress.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import time
-
-import gradio as gr
-
-from modules.shared import opts
-
-import modules.shared as shared
-
-
-def calc_time_left(progress, threshold, label, force_display, show_eta):
- if progress == 0:
- return ""
- else:
- time_since_start = time.time() - shared.state.time_start
- eta = (time_since_start/progress)
- eta_relative = eta-time_since_start
- if (eta_relative > threshold and show_eta) or force_display:
- if eta_relative > 3600:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
- elif eta_relative > 60:
- return label + time.strftime('%M:%S', time.gmtime(eta_relative))
- else:
- return label + time.strftime('%Ss', time.gmtime(eta_relative))
- else:
- return ""
-
-
-def check_progress_call(id_part):
- if shared.state.job_count == 0:
- return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
- progress = 0
-
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
- # Show progress percentage and time left at the same moment, and base it also on steps done
- show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
-
- time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
- if time_left != "":
- shared.state.time_left_force_display = True
-
- progress = min(progress, 1)
-
- progressbar = ""
- if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
-
- image = gr.update(visible=False)
- preview_visibility = gr.update(visible=False)
-
- if opts.show_progress_every_n_steps != 0:
- shared.state.set_current_image()
- image = shared.state.current_image
-
- if image is None:
- image = gr.update(value=None)
- else:
- preview_visibility = gr.update(visible=True)
-
- if shared.state.textinfo is not None:
- textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
- else:
- textinfo_result = gr.update(visible=False)
-
- return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
- shared.state.job_count = -1
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.textinfo = None
- shared.state.time_start = time.time()
- shared.state.time_left_force_display = False
-
- return check_progress_call(id_part)
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
- if textinfo is None:
- textinfo = gr.HTML(visible=False)
-
- check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
- check_progress.click(
- fn=lambda: check_progress_call(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
-
- check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
- check_progress_initial.click(
- fn=lambda: check_progress_call_initial(id_part),
- show_progress=False,
- inputs=[],
- outputs=[progressbar, preview, preview, textinfo],
- )
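For context, the arithmetic the deleted helpers performed: overall progress was the finished-jobs fraction plus the running job's step fraction, and the remaining time was extrapolated from how long that fraction had taken. A small numeric illustration with made-up values:

import time

job_no, job_count = 1, 4                 # second job of four is running
sampling_step, sampling_steps = 10, 20   # halfway through its steps
time_since_start = 60.0                  # seconds elapsed so far

progress = job_no / job_count + (1 / job_count) * (sampling_step / sampling_steps)
eta = time_since_start / progress        # projected total runtime
eta_relative = eta - time_since_start    # time still remaining

print(f"progress={progress:.3f}")                         # 0.375
print(f"remaining={eta_relative:.0f}s")                    # 100s
print(time.strftime('%M:%S', time.gmtime(eta_relative)))   # 01:40, as the old bar would have rendered it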
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 231680cb..a5bf5acb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -95,6 +95,7 @@ class UpscalerData:
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
+ self.local_data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
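Both path fields start out identical, so the extra attribute is presumably there so that code which later rewrites data_path (for example to point at a remote location) can still reach the file on disk; that rationale is an assumption, as is the path below.

from modules.upscaler import UpscalerData

data = UpscalerData(name="4x-Example", path="/models/ESRGAN/4x-example.pth")  # made-up path
assert data.data_path == data.local_data_path

data.data_path = "https://example.com/4x-example.pth"  # hypothetical later rewrite
print(data.local_data_path)                            # still /models/ESRGAN/4x-example.pth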