aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/lora.py21
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py1
-rw-r--r--html/extra-networks-card.html2
-rw-r--r--javascript/extraNetworks.js40
-rw-r--r--javascript/imageviewer.js2
-rw-r--r--launch.py30
-rw-r--r--modules/api/api.py83
-rw-r--r--modules/api/models.py2
-rw-r--r--modules/mac_specific.py3
-rw-r--r--modules/memmon.py12
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py2
-rw-r--r--modules/sd_models.py24
-rw-r--r--modules/sd_vae_approx.py5
-rw-r--r--modules/shared.py1
-rw-r--r--modules/timer.py3
-rw-r--r--modules/ui.py3
-rw-r--r--modules/ui_extensions.py2
-rw-r--r--modules/ui_extra_networks.py13
-rw-r--r--requirements_versions.txt2
-rw-r--r--scripts/xyz_grid.py15
-rw-r--r--style.css65
-rw-r--r--test/basic_features/extras_test.py8
-rw-r--r--test/basic_features/img2img_test.py8
-rw-r--r--test/server_poll.py6
-rw-r--r--webui.py24
25 files changed, 313 insertions, 64 deletions
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index cb8f1d36..8937b585 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -3,7 +3,9 @@ import os
import re
import torch
-from modules import shared, devices, sd_models
+from modules import shared, devices, sd_models, errors
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
@@ -43,6 +45,23 @@ class LoraOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
+ self.metadata = {}
+
+ _, ext = os.path.splitext(filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading lora {filename}")
+
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+
+ self.metadata = m
+
+ self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
class LoraModule:
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 8d32052e..68b11332 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -23,6 +23,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}
def allowed_directories_for_previews(self):
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 8612396d..1bf3fc30 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,4 +1,6 @@
<div class='card' {preview_html} onclick={card_clicked}>
+ {metadata_button}
+
<div class='actions'>
<div class='additional'>
<ul>
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 5781df4f..2fb87cd5 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -5,12 +5,10 @@ function setupExtraNetworksForTab(tabname){
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
- var close = gradioApp().getElementById(tabname+'_extra_close')
search.classList.add('search')
tabs.appendChild(search)
tabs.appendChild(refresh)
- tabs.appendChild(close)
search.addEventListener("input", function(evt){
searchTerm = search.value.toLowerCase()
@@ -104,4 +102,40 @@ function extraNetworksSearchButton(tabs_id, event){
searchTextarea.value = text
updateInput(searchTextarea)
-} \ No newline at end of file
+}
+
+var globalPopup = null;
+var globalPopupInner = null;
+function popup(contents){
+ if(! globalPopup){
+ globalPopup = document.createElement('div')
+ globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
+ globalPopup.classList.add('global-popup');
+
+ var close = document.createElement('div')
+ close.classList.add('global-popup-close');
+ close.onclick = function(){ globalPopup.style.display = "none"; };
+ close.title = "Close";
+ globalPopup.appendChild(close)
+
+ globalPopupInner = document.createElement('div')
+ globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
+ globalPopupInner.classList.add('global-popup-inner');
+ globalPopup.appendChild(globalPopupInner)
+
+ gradioApp().appendChild(globalPopup);
+ }
+
+ globalPopupInner.innerHTML = '';
+ globalPopupInner.appendChild(contents);
+
+ globalPopup.style.display = "flex";
+}
+
+function extraNetworksShowMetadata(text){
+ elem = document.createElement('pre')
+ elem.classList.add('popup-metadata');
+ elem.textContent = text;
+
+ popup(elem);
+}
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index aac2ee82..28e748b7 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -11,7 +11,7 @@ function showModal(event) {
if (modalImage.style.display === 'none') {
lb.style.setProperty('background-image', 'url(' + source.src + ')');
}
- lb.style.display = "block";
+ lb.style.display = "flex";
lb.focus()
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
diff --git a/launch.py b/launch.py
index 0868f8a9..b943fed2 100644
--- a/launch.py
+++ b/launch.py
@@ -8,6 +8,14 @@ import platform
import argparse
import json
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--ui-settings-file", type=str, default='config.json')
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
+args, _ = parser.parse_known_args(sys.argv)
+
+script_path = os.path.dirname(__file__)
+data_path = os.getcwd()
+
dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
@@ -122,7 +130,7 @@ def is_installed(package):
def repo_dir(name):
- return os.path.join(dir_repos, name)
+ return os.path.join(script_path, dir_repos, name)
def run_python(code, desc=None, errdesc=None):
@@ -215,7 +223,7 @@ def list_extensions(settings_file):
disabled_extensions = set(settings.get('disabled_extensions', []))
- return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
+ return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
def run_extensions_installers(settings_file):
@@ -252,10 +260,6 @@ def prepare_environment():
sys.argv += shlex.split(commandline_args)
- parser = argparse.ArgumentParser(add_help=False)
- parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
- args, _ = parser.parse_known_args(sys.argv)
-
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
@@ -306,7 +310,7 @@ def prepare_environment():
if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
- os.makedirs(dir_repos, exist_ok=True)
+ os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
@@ -315,9 +319,11 @@ def prepare_environment():
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
if not is_installed("lpips"):
- run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
+ run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
- run_pip(f"install -r {requirements_file}", "requirements for Web UI")
+ if not os.path.isfile(requirements_file):
+ requirements_file = os.path.join(script_path, requirements_file)
+ run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
run_extensions_installers(settings_file=args.ui_settings_file)
@@ -325,7 +331,7 @@ def prepare_environment():
version_check(commit)
if update_all_extensions:
- git_pull_recursive(dir_extensions)
+ git_pull_recursive(os.path.join(data_path, dir_extensions))
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
@@ -341,7 +347,7 @@ def tests(test_dir):
sys.argv.append("--api")
if "--ckpt" not in sys.argv:
sys.argv.append("--ckpt")
- sys.argv.append("./test/test_files/empty.pt")
+ sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
@@ -350,7 +356,7 @@ def tests(test_dir):
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
os.environ['COMMANDLINE_ARGS'] = ""
- with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
+ with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
import test.server_poll
diff --git a/modules/api/api.py b/modules/api/api.py
index 376f7f04..35e17afc 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -164,14 +164,10 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
- def get_script(self, script_name, script_runner):
- if script_name is None:
+ def get_selectable_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
return None, None
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
-
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
@@ -182,8 +178,49 @@ class Api:
return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+ def get_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
+ return None, None
+
+ script_idx = script_name_to_index(script_name, script_runner.scripts)
+ return script_runner.scripts[script_idx]
+
+ def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
+ #find max idx from the scripts in runner and generate a none array to init script_args
+ last_arg_index = 1
+ for script in script_runner.scripts:
+ if last_arg_index < script.args_to:
+ last_arg_index = script.args_to
+ # None everywhere except position 0 to initialize script args
+ script_args = [None]*last_arg_index
+ # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
+ if selectable_scripts:
+ script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
+ script_args[0] = selectable_idx + 1
+ else:
+ # when [0] = 0 no selectable script to run
+ script_args[0] = 0
+
+ # Now check for always on scripts
+ if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+ for alwayson_script_name in request.alwayson_scripts.keys():
+ alwayson_script = self.get_script(alwayson_script_name, script_runner)
+ if alwayson_script == None:
+ raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
+ # Selectable script in always on script param check
+ if alwayson_script.alwayson == False:
+ raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ # always on script with no arg should always run so you don't really need to add them to the requests
+ if "args" in request.alwayson_scripts[alwayson_script_name]:
+ script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ return script_args
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
- script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+ script_runner = scripts.scripts_txt2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
@@ -195,20 +232,26 @@ class Api:
args = vars(populate)
args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
+ p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin()
- if script is not None:
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
@@ -221,12 +264,16 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
- script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
-
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
+ script_runner = scripts.scripts_img2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(True)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+
populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": not img2imgreq.save_images,
@@ -239,6 +286,10 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
@@ -246,14 +297,16 @@ class Api:
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin()
- if script is not None:
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args)
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
diff --git a/modules/api/models.py b/modules/api/models.py
index fa1c40df..4a70f440 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -106,6 +106,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "script_args", "type": list, "default": []},
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
]
).generate_model()
@@ -122,6 +123,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "script_args", "type": list, "default": []},
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
]
).generate_model()
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..18e6ff72 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -45,7 +45,6 @@ if has_mps:
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
diff --git a/modules/memmon.py b/modules/memmon.py
index a7060f58..4018edcc 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
self.data = defaultdict(int)
try:
- torch.cuda.mem_get_info()
+ self.cuda_mem_get_info()
torch.cuda.memory_stats(self.device)
except Exception as e: # AMD or whatever
print(f"Warning: caught exception '{e}', memory monitor disabled")
self.disabled = True
+ def cuda_mem_get_info(self):
+ index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+ return torch.cuda.mem_get_info(index)
+
def run(self):
if self.disabled:
return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
self.run_flag.clear()
continue
- self.data["min_free"] = torch.cuda.mem_get_info()[0]
+ self.data["min_free"] = self.cuda_mem_get_info()[0]
while self.run_flag.is_set():
- free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
+ free, total = self.cuda_mem_get_info()
self.data["min_free"] = min(self.data["min_free"], free)
time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
- free, total = torch.cuda.mem_get_info()
+ free, total = self.cuda_mem_get_info()
self.data["free"] = free
self.data["total"] = total
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index df63d1bc..e9a093a2 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -719,7 +719,7 @@ class UniPC:
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = (
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
)
if x_t is None:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 93959f55..5f57ec0c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+def read_metadata_from_safetensors(filename):
+ import json
+
+ with open(filename, mode="rb") as file:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+
+ assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+
+ res = {}
+ for k, v in json_obj.get("__metadata__", {}).items():
+ res[k] = v
+ if isinstance(v, str) and v[0] == '{':
+ try:
+ res[k] = json.loads(v)
+ except Exception as e:
+ pass
+
+ return res
+
+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0027343a..e2f00468 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -35,8 +35,11 @@ def model():
global sd_vae_approx_model
if sd_vae_approx_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
+ sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
diff --git a/modules/shared.py b/modules/shared.py
index 2fb9e3b5..f28a12cc 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -714,6 +714,7 @@ class TotalTQDM:
def clear(self):
if self._tqdm is not None:
+ self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
diff --git a/modules/timer.py b/modules/timer.py
index 57a4f17a..ba92be33 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -33,3 +33,6 @@ class Timer:
res += ")"
return res
+
+ def reset(self):
+ self.__init__()
diff --git a/modules/ui.py b/modules/ui.py
index 621ae952..7e603332 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1751,7 +1751,8 @@ def create_ui():
def reload_javascript():
- head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
+ script_js = os.path.join(script_path, "script.js")
+ head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index bd4308ef..df75a925 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -304,7 +304,7 @@ def create_ui():
with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 85f0af4c..cdfd6f2a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -30,8 +30,8 @@ def add_pages_to_demo(app):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+ if ext not in (".png", ".jpg", ".webp"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -124,6 +124,12 @@ class ExtraNetworksPage:
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ metadata_button = ""
+ metadata = item.get("metadata")
+ if metadata:
+ metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
+ metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
+
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item.get("prompt", None),
@@ -134,6 +140,7 @@ class ExtraNetworksPage:
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
+ "metadata_button": metadata_button,
}
return self.card_page.format(**args)
@@ -213,7 +220,6 @@ def create_ui(container, button, tabname):
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
- button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
@@ -224,7 +230,6 @@ def create_ui(container, button, tabname):
state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
- button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh():
res = []
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 41e0ccc5..0031c616 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -27,4 +27,4 @@ GitPython==3.1.30
torchsde==0.2.5
safetensors==0.2.7
httpcore<=0.15
-fastapi==0.90.1
+fastapi==0.94.0
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 9a0678fa..ce584981 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -132,6 +132,20 @@ def apply_uni_pc_order(p, x, xs):
opts.data["uni_pc_order"] = min(x, p.steps - 1)
+def apply_face_restore(p, opt, x):
+ opt = opt.lower()
+ if opt == 'codeformer':
+ is_active = True
+ p.face_restoration_model = 'CodeFormer'
+ elif opt == 'gfpgan':
+ is_active = True
+ p.face_restoration_model = 'GFPGAN'
+ else:
+ is_active = opt in ('true', 'yes', 'y', '1')
+
+ p.restore_faces = is_active
+
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -210,6 +224,7 @@ axis_options = [
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
+ AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
]
diff --git a/style.css b/style.css
index fc0f47a6..3eac2b17 100644
--- a/style.css
+++ b/style.css
@@ -362,6 +362,46 @@ input[type="range"]{
height: 100%;
}
+.popup-metadata{
+ color: black;
+ background: white;
+ display: inline-block;
+ padding: 1em;
+ white-space: pre-wrap;
+}
+
+.global-popup{
+ display: flex;
+ position: fixed;
+ z-index: 1001;
+ left: 0;
+ top: 0;
+ width: 100%;
+ height: 100%;
+ overflow: auto;
+ background-color: rgba(20, 20, 20, 0.95);
+}
+
+
+.global-popup-close:before {
+ content: "×";
+}
+
+.global-popup-close{
+ position: fixed;
+ right: 0.25em;
+ top: 0;
+ cursor: pointer;
+ color: white;
+ font-size: 32pt;
+}
+
+.global-popup-inner{
+ display: inline-block;
+ margin: auto;
+ padding: 2em;
+}
+
#lightboxModal{
display: none;
position: fixed;
@@ -436,9 +476,7 @@ input[type="range"]{
#modalImage {
display: block;
- margin-left: auto;
- margin-right: auto;
- margin-top: auto;
+ margin: auto;
width: auto;
}
@@ -839,6 +877,27 @@ footer {
margin-left: 0.5em;
}
+
+.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
+ content: "🛈";
+}
+.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
+ display: none;
+ position: absolute;
+ right: 0;
+ color: white;
+ text-shadow: 2px 2px 3px black;
+ padding: 0.25em;
+ font-size: 22pt;
+}
+.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
+ display: inline-block;
+}
+.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{
+ color: red;
+}
+
+
.extra-network-thumbs {
display: flex;
flex-flow: row wrap;
diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py
index 0170c511..8ed98747 100644
--- a/test/basic_features/extras_test.py
+++ b/test/basic_features/extras_test.py
@@ -1,7 +1,9 @@
+import os
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image
+from modules.paths import script_path
class TestExtrasWorking(unittest.TestCase):
def setUp(self):
@@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase):
"upscaler_1": "None",
"upscaler_2": "None",
"extras_upscaler_2_visibility": 0,
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+ "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
}
def test_simple_upscaling_performed(self):
@@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase):
def setUp(self):
self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
self.png_info = {
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+ "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
}
def test_png_info_performed(self):
@@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase):
def setUp(self):
self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
self.interrogate = {
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
+ "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))),
"model": "clip"
}
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
index 08c5c903..5240ec36 100644
--- a/test/basic_features/img2img_test.py
+++ b/test/basic_features/img2img_test.py
@@ -1,14 +1,16 @@
+import os
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image
+from modules.paths import script_path
class TestImg2ImgWorking(unittest.TestCase):
def setUp(self):
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
self.simple_img2img = {
- "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
+ "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))],
"resize_mode": 0,
"denoising_strength": 0.75,
"mask": None,
@@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+ self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_with_inverted_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+ self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
self.simple_img2img["inpainting_mask_invert"] = True
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
diff --git a/test/server_poll.py b/test/server_poll.py
index 42d56a4c..c732630f 100644
--- a/test/server_poll.py
+++ b/test/server_poll.py
@@ -1,6 +1,8 @@
import unittest
import requests
import time
+import os
+from modules.paths import script_path
def run_tests(proc, test_dir):
@@ -15,8 +17,8 @@ def run_tests(proc, test_dir):
break
if proc.poll() is None:
if test_dir is None:
- test_dir = "test"
- suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
+ test_dir = os.path.join(script_path, "test")
+ suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
result = unittest.TextTestRunner(verbosity=2).run(suite)
return len(result.failures) + len(result.errors)
else:
diff --git a/webui.py b/webui.py
index 32561877..aaec79fd 100644
--- a/webui.py
+++ b/webui.py
@@ -183,13 +183,16 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler)
-def setup_cors(app):
+def setup_middleware(app):
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
def create_api(app):
@@ -213,8 +216,7 @@ def api_only():
initialize()
app = FastAPI()
- setup_cors(app)
- app.add_middleware(GZipMiddleware, minimum_size=1000)
+ setup_middleware(app)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
@@ -271,9 +273,7 @@ def webui():
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
- setup_cors(app)
-
- app.add_middleware(GZipMiddleware, minimum_size=1000)
+ setup_middleware(app)
modules.progress.setup_progress_api(app)
@@ -290,24 +290,35 @@ def webui():
wait_on_server(shared.demo)
print('Restarting UI...')
+ startup_timer.reset()
+
sd_samplers.set_samplers()
modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
+ startup_timer.record("list extensions")
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
+ startup_timer.record("load scripts")
+
modules.script_callbacks.model_loaded_callback(shared.sd_model)
+ startup_timer.record("model loaded callback")
+
modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
+ startup_timer.record("reload script modules")
modules.sd_models.list_models()
+ startup_timer.record("list SD models")
shared.reload_hypernetworks()
+ startup_timer.record("reload hypernetworks")
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
@@ -316,6 +327,7 @@ def webui():
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+ startup_timer.record("initialize extra networks")
if __name__ == "__main__":