aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/config_states.py2
-rw-r--r--modules/extensions.py19
-rw-r--r--modules/extras.py12
-rw-r--r--modules/processing.py20
-rw-r--r--modules/sd_models.py1
-rw-r--r--modules/sd_samplers.py8
-rw-r--r--modules/sd_samplers_common.py35
-rw-r--r--modules/sd_samplers_compvis.py8
-rw-r--r--modules/sd_samplers_kdiffusion.py22
-rw-r--r--modules/sd_vae_taesd.py88
-rw-r--r--modules/shared.py4
-rw-r--r--modules/ui.py10
-rw-r--r--modules/ui_extensions.py26
-rw-r--r--modules/ui_extra_networks.py19
14 files changed, 217 insertions, 57 deletions
diff --git a/modules/config_states.py b/modules/config_states.py
index 75da862a..db65bcdb 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -83,6 +83,8 @@ def get_extension_config():
ext_config = {}
for ext in extensions.extensions:
+ ext.read_info_from_repo()
+
entry = {
"name": ext.name,
"path": ext.path,
diff --git a/modules/extensions.py b/modules/extensions.py
index bc2c0450..359a7aa5 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,8 +1,8 @@
import os
import sys
+import threading
import traceback
-import time
import git
from modules import shared
@@ -24,6 +24,8 @@ def active():
class Extension:
+ lock = threading.Lock()
+
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
self.path = path
@@ -42,8 +44,13 @@ class Extension:
if self.is_builtin or self.have_info_from_repo:
return
- self.have_info_from_repo = True
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+ self.do_read_info_from_repo()
+
+ def do_read_info_from_repo(self):
repo = None
try:
if os.path.exists(os.path.join(self.path, ".git")):
@@ -58,18 +65,18 @@ class Extension:
try:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
- head = repo.head.commit
self.commit_date = repo.head.commit.committed_date
- ts = time.asctime(time.gmtime(self.commit_date))
if repo.active_branch:
self.branch = repo.active_branch.name
- self.commit_hash = head.hexsha
- self.version = f'{self.commit_hash[:8]} ({ts})'
+ self.commit_hash = repo.head.commit.hexsha
+ self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
except Exception as ex:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
+ self.have_info_from_repo = True
+
def list_files(self, subdir, extension):
from modules import scripts
diff --git a/modules/extras.py b/modules/extras.py
index bdf9b3b7..830b53aa 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -242,9 +242,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+ metadata = None
if save_metadata:
+ metadata = {"format": "pt"}
+
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -262,15 +264,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
}
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+ sd_merge_models = {}
+
def add_model_metadata(checkpoint_info):
checkpoint_info.calculate_shorthash()
- metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ sd_merge_models[checkpoint_info.sha256] = {
"name": checkpoint_info.name,
"legacy_hash": checkpoint_info.hash,
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
}
- metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+ sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
add_model_metadata(primary_model_info)
if secondary_model_info:
@@ -278,7 +282,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info:
add_model_metadata(tertiary_model_info)
- metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+ metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
diff --git a/modules/processing.py b/modules/processing.py
index 94fe2625..cd63b9a6 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -316,6 +316,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -480,6 +481,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
+ uses_ensd = opts.eta_noise_seed_delta != 0
+ if uses_ensd:
+ uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
+
generation_params = {
"Steps": p.steps,
"Sampler": p.sampler_name,
@@ -496,17 +501,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
- "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
"Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
}
- generation_params.update(p.extra_generation_params)
-
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
@@ -678,12 +682,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- step_multiplier = 1
- if not shared.opts.dont_fix_second_order_samplers_schedule:
- try:
- step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
- except Exception:
- pass
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+ step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index dddbc6e1..e612be10 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -540,7 +540,6 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4f1bf21d..f22aad8f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -14,12 +14,18 @@ samplers_for_img2img = []
samplers_map = {}
-def create_sampler(name, model):
+def find_sampler_config(name):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
+ return config
+
+
+def create_sampler(name, model):
+ config = find_sampler_config(name)
+
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bc074238..763829f1 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def single_sample_to_image(sample, approximation=None):
@@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
+ x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+ elif approximation == 3:
+ x_sample = sample * 1.5
+ x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
+
return Image.fromarray(x_sample)
@@ -58,6 +62,25 @@ def store_latent(decoded):
shared.state.assign_current_image(sample_to_image(decoded))
+def is_sampler_using_eta_noise_seed_delta(p):
+ """returns whether sampler from config will use eta noise seed delta for image creation"""
+
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+
+ eta = p.eta
+
+ if eta is None and p.sampler is not None:
+ eta = p.sampler.eta
+
+ if eta is None and sampler_config is not None:
+ eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
+
+ if eta == 0:
+ return False
+
+ return sampler_config.options.get("uses_ensd", False)
+
+
class InterruptedException(BaseException):
pass
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index b1ee3be7..bdae8b40 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc
samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]
@@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler:
self.update_step(x)
def initialize(self, p):
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ if self.is_ddim:
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ else:
+ self.eta = 0.0
+
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 61f23ad7..552c6c64 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -11,23 +11,23 @@ from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
]
samplers_data_k_diffusion = [
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
new file mode 100644
index 00000000..5e8496e8
--- /dev/null
+++ b/modules/sd_vae_taesd.py
@@ -0,0 +1,88 @@
+"""
+Tiny AutoEncoder for Stable Diffusion
+(DNN for encoding / decoding SD's latent space)
+
+https://github.com/madebyollin/taesd
+"""
+import os
+import torch
+import torch.nn as nn
+
+from modules import devices, paths_internal
+
+sd_vae_taesd = None
+
+
+def conv(n_in, n_out, **kwargs):
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class Clamp(nn.Module):
+ @staticmethod
+ def forward(x):
+ return torch.tanh(x / 3) * 3
+
+
+class Block(nn.Module):
+ def __init__(self, n_in, n_out):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.fuse = nn.ReLU()
+
+ def forward(self, x):
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+def decoder():
+ return nn.Sequential(
+ Clamp(), conv(4, 64), nn.ReLU(),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), conv(64, 3),
+ )
+
+
+class TAESD(nn.Module):
+ latent_magnitude = 3
+ latent_shift = 0.5
+
+ def __init__(self, decoder_path="taesd_decoder.pth"):
+ """Initialize pretrained TAESD on the given device from the given checkpoints."""
+ super().__init__()
+ self.decoder = decoder()
+ self.decoder.load_state_dict(
+ torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+
+ @staticmethod
+ def unscale_latents(x):
+ """[0, 1] -> raw latents"""
+ return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+
+def download_model(model_path):
+ model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading TAESD decoder to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
+def model():
+ global sd_vae_taesd
+
+ if sd_vae_taesd is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
+ download_model(model_path)
+
+ if os.path.exists(model_path):
+ sd_vae_taesd = TAESD(model_path)
+ sd_vae_taesd.eval()
+ sd_vae_taesd.to(devices.device, devices.dtype)
+ else:
+ raise FileNotFoundError('TAESD model not found')
+
+ return sd_vae_taesd.decoder
diff --git a/modules/shared.py b/modules/shared.py
index 07f18b1b..165509ea 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -448,7 +448,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
- "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}).info("Full = slow but pretty; Approx NN = fast but low quality; Approx cheap = super fast but terrible otherwise"),
+ "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))
@@ -458,8 +458,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
+ 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
diff --git a/modules/ui.py b/modules/ui.py
index ff25c4ce..8e51e782 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1841,15 +1841,15 @@ def versions_html():
return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
- •
+&#x2000;•&#x2000;
python: <span title="{sys.version}">{python_version}</span>
- •
+&#x2000;•&#x2000;
torch: {getattr(torch, '__long_version__',torch.__version__)}
- •
+&#x2000;•&#x2000;
xformers: {xformers_version}
- •
+&#x2000;•&#x2000;
gradio: {gr.__version__}
- •
+&#x2000;•&#x2000;
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index af497733..d7a0f685 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,6 +1,7 @@
import json
import os.path
import sys
+import threading
import time
from datetime import datetime
import traceback
@@ -140,7 +141,9 @@ def extension_table():
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
- <th><abbr title="Extension version">Version</abbr></th>
+ <th>Branch</th>
+ <th>Version</th>
+ <th>Date</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
@@ -148,6 +151,7 @@ def extension_table():
"""
for ext in extensions.extensions:
+ ext: extensions.Extension
ext.read_info_from_repo()
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
@@ -169,7 +173,9 @@ def extension_table():
<tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
+ <td>{ext.branch}</td>
<td>{version_link}</td>
+ <td>{time.asctime(time.gmtime(ext.commit_date))}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -484,11 +490,18 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
return code, list(tags)
+def preload_extensions_git_metadata():
+ for extension in extensions.extensions:
+ extension.read_info_from_repo()
+
+
def create_ui():
import modules.ui
config_states.list_config_states()
+ threading.Thread(target=preload_extensions_git_metadata).start()
+
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions"):
with gr.TabItem("Installed", id="installed"):
@@ -508,7 +521,8 @@ def create_ui():
</span>
"""
info = gr.HTML(html)
- extensions_table = gr.HTML(lambda: extension_table())
+ extensions_table = gr.HTML('Loading...')
+ ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
apply.click(
fn=apply_and_restart,
@@ -579,9 +593,9 @@ def create_ui():
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
- fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
inputs=[install_dirname, install_url, install_branch],
- outputs=[extensions_table, install_result],
+ outputs=[install_url, extensions_table, install_result],
)
with gr.TabItem("Backup/Restore"):
@@ -595,7 +609,8 @@ def create_ui():
config_save_button = gr.Button(value="Save Current Config")
config_states_info = gr.HTML("")
- config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+ config_states_table = gr.HTML("Loading...")
+ ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
@@ -608,4 +623,5 @@ def create_ui():
outputs=[config_states_table],
)
+
return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 0baccf56..752cf2b8 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -268,7 +268,7 @@ def create_ui(container, button, tabname):
with gr.Tab(page.title, id=page_id):
elem_id = f"{tabname}_{page_id}_cards_html"
- page_elem = gr.HTML('', elem_id=elem_id)
+ page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
@@ -282,13 +282,24 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible):
is_visible = not is_visible
- if is_visible and not ui.pages_contents:
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ def fill_tabs(is_empty):
+ """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
+
+ if not ui.pages_contents:
refresh()
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
+ if is_empty:
+ return True, *ui.pages_contents
+
+ return True, *[gr.update() for _ in ui.pages_contents]
state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
+
+ state_empty = gr.State(value=True)
+ button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
def refresh():
for pg in ui.stored_extra_pages: