aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorInvincibleDude <81354513+InvincibleDude@users.noreply.github.com>2023-03-14 16:55:59 +0300
committerGitHub <noreply@github.com>2023-03-14 16:55:59 +0300
commitf5e44364535ccc7efef445dacb6395c3942f2f17 (patch)
treeca7384e6225fccdae1e89db1e0ccc44dd60c7311
parentf6e27378404631d951656388fc178b784fe1495b (diff)
parenta9fed7c364061ae6efb37f797b6b522cb3cf7aa2 (diff)
Merge branch 'master' into improved-hr-conflict-test
-rw-r--r--README.md1
-rw-r--r--extensions-builtin/Lora/lora.py21
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py14
-rw-r--r--html/extra-networks-card.html3
-rw-r--r--html/licenses.html219
-rw-r--r--javascript/extraNetworks.js42
-rw-r--r--javascript/hints.js1
-rw-r--r--javascript/imageviewer.js2
-rw-r--r--javascript/notification.js2
-rw-r--r--javascript/progressbar.js2
-rw-r--r--launch.py44
-rw-r--r--modules/api/api.py128
-rw-r--r--modules/api/models.py30
-rw-r--r--modules/codeformer_model.py2
-rw-r--r--modules/extensions.py6
-rw-r--r--modules/generation_parameters_copypaste.py13
-rw-r--r--modules/images.py7
-rw-r--r--modules/mac_specific.py3
-rw-r--r--modules/memmon.py12
-rw-r--r--modules/modelloader.py8
-rw-r--r--modules/models/diffusion/uni_pc/__init__.py1
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py100
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py857
-rw-r--r--modules/processing.py11
-rw-r--r--modules/script_callbacks.py8
-rw-r--r--modules/scripts.py32
-rw-r--r--modules/sd_hijack.py12
-rw-r--r--modules/sd_hijack_optimizations.py70
-rw-r--r--modules/sd_models.py24
-rw-r--r--modules/sd_samplers.py2
-rw-r--r--modules/sd_samplers_compvis.py59
-rw-r--r--modules/sd_samplers_kdiffusion.py4
-rw-r--r--modules/sd_vae_approx.py5
-rw-r--r--modules/shared.py26
-rw-r--r--modules/timer.py3
-rw-r--r--modules/ui.py14
-rw-r--r--modules/ui_common.py9
-rw-r--r--modules/ui_extensions.py2
-rw-r--r--modules/ui_extra_networks.py43
-rw-r--r--modules/ui_extra_networks_checkpoints.py14
-rw-r--r--modules/ui_extra_networks_hypernets.py12
-rw-r--r--modules/ui_extra_networks_textual_inversion.py13
-rw-r--r--requirements_versions.txt4
-rw-r--r--scripts/prompt_matrix.py2
-rw-r--r--scripts/xyz_grid.py155
-rw-r--r--style.css78
-rw-r--r--test/basic_features/extras_test.py8
-rw-r--r--test/basic_features/img2img_test.py8
-rw-r--r--test/basic_features/txt2img_test.py2
-rw-r--r--test/server_poll.py6
-rw-r--r--webui.py75
51 files changed, 2007 insertions, 212 deletions
diff --git a/README.md b/README.md
index 2ceb4d2d..24f8e799 100644
--- a/README.md
+++ b/README.md
@@ -157,5 +157,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
+- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index cb8f1d36..8937b585 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -3,7 +3,9 @@ import os
import re
import torch
-from modules import shared, devices, sd_models
+from modules import shared, devices, sd_models, errors
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
@@ -43,6 +45,23 @@ class LoraOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
+ self.metadata = {}
+
+ _, ext = os.path.splitext(filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading lora {filename}")
+
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+
+ self.metadata = m
+
+ self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
class LoraModule:
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 22cabcb0..68b11332 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -15,21 +15,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
-
yield {
"name": name,
"filename": path,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": path + ".png",
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}
def allowed_directories_for_previews(self):
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 8a5e2fbd..1bf3fc30 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,4 +1,6 @@
<div class='card' {preview_html} onclick={card_clicked}>
+ {metadata_button}
+
<div class='actions'>
<div class='additional'>
<ul>
@@ -7,6 +9,7 @@
<span style="display:none" class='search_term'>{search_term}</span>
</div>
<span class='name'>{name}</span>
+ <span class='description'>{description}</span>
</div>
</div>
diff --git a/html/licenses.html b/html/licenses.html
index 570630eb..bddbf466 100644
--- a/html/licenses.html
+++ b/html/licenses.html
@@ -417,3 +417,222 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
+<h2><a href="https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/LICENSE">Scaled Dot Product Attention</a></h2>
+<small>Some small amounts of code borrowed and reworked.</small>
+<pre>
+ Copyright 2023 The HuggingFace Team. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+</pre> \ No newline at end of file
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 17bf2000..2fb87cd5 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -5,12 +5,10 @@ function setupExtraNetworksForTab(tabname){
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
- var close = gradioApp().getElementById(tabname+'_extra_close')
search.classList.add('search')
tabs.appendChild(search)
tabs.appendChild(refresh)
- tabs.appendChild(close)
search.addEventListener("input", function(evt){
searchTerm = search.value.toLowerCase()
@@ -78,7 +76,7 @@ function cardClicked(tabname, textToAdd, allowNegativePrompt){
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
- textarea.value = textarea.value + " " + textToAdd
+ textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
}
updateInput(textarea)
@@ -104,4 +102,40 @@ function extraNetworksSearchButton(tabs_id, event){
searchTextarea.value = text
updateInput(searchTextarea)
-} \ No newline at end of file
+}
+
+var globalPopup = null;
+var globalPopupInner = null;
+function popup(contents){
+ if(! globalPopup){
+ globalPopup = document.createElement('div')
+ globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
+ globalPopup.classList.add('global-popup');
+
+ var close = document.createElement('div')
+ close.classList.add('global-popup-close');
+ close.onclick = function(){ globalPopup.style.display = "none"; };
+ close.title = "Close";
+ globalPopup.appendChild(close)
+
+ globalPopupInner = document.createElement('div')
+ globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
+ globalPopupInner.classList.add('global-popup-inner');
+ globalPopup.appendChild(globalPopupInner)
+
+ gradioApp().appendChild(globalPopup);
+ }
+
+ globalPopupInner.innerHTML = '';
+ globalPopupInner.appendChild(contents);
+
+ globalPopup.style.display = "flex";
+}
+
+function extraNetworksShowMetadata(text){
+ elem = document.createElement('pre')
+ elem.classList.add('popup-metadata');
+ elem.textContent = text;
+
+ popup(elem);
+}
diff --git a/javascript/hints.js b/javascript/hints.js
index f1199009..7f4101b2 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -6,6 +6,7 @@ titles = {
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+ "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index aac2ee82..28e748b7 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -11,7 +11,7 @@ function showModal(event) {
if (modalImage.style.display === 'none') {
lb.style.setProperty('background-image', 'url(' + source.src + ')');
}
- lb.style.display = "block";
+ lb.style.display = "flex";
lb.focus()
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
diff --git a/javascript/notification.js b/javascript/notification.js
index 040a3afa..5ae6df24 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -15,7 +15,7 @@ onUiUpdate(function(){
}
}
- const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
+ const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
if (galleryPreviews == null) return;
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index ff6d757b..9ccc9da4 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -139,7 +139,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var divProgress = document.createElement('div')
divProgress.className='progressDiv'
- divProgress.style.display = opts.show_progressbar ? "" : "none"
+ divProgress.style.display = opts.show_progressbar ? "block" : "none"
var divInner = document.createElement('div')
divInner.className='progress'
diff --git a/launch.py b/launch.py
index a68bb3a9..b943fed2 100644
--- a/launch.py
+++ b/launch.py
@@ -8,6 +8,14 @@ import platform
import argparse
import json
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--ui-settings-file", type=str, default='config.json')
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
+args, _ = parser.parse_known_args(sys.argv)
+
+script_path = os.path.dirname(__file__)
+data_path = os.getcwd()
+
dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
@@ -122,7 +130,7 @@ def is_installed(package):
def repo_dir(name):
- return os.path.join(dir_repos, name)
+ return os.path.join(script_path, dir_repos, name)
def run_python(code, desc=None, errdesc=None):
@@ -161,7 +169,17 @@ def git_clone(url, dir, name, commithash=None):
if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
-
+
+def git_pull_recursive(dir):
+ for subdir, _, _ in os.walk(dir):
+ if os.path.exists(os.path.join(subdir, '.git')):
+ try:
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
+ except subprocess.CalledProcessError as e:
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
+
+
def version_check(commit):
try:
import requests
@@ -205,7 +223,7 @@ def list_extensions(settings_file):
disabled_extensions = set(settings.get('disabled_extensions', []))
- return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
+ return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
def run_extensions_installers(settings_file):
@@ -242,11 +260,8 @@ def prepare_environment():
sys.argv += shlex.split(commandline_args)
- parser = argparse.ArgumentParser(add_help=False)
- parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
- args, _ = parser.parse_known_args(sys.argv)
-
sys.argv, _ = extract_arg(sys.argv, '-f')
+ sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
@@ -295,7 +310,7 @@ def prepare_environment():
if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
- os.makedirs(dir_repos, exist_ok=True)
+ os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
@@ -304,14 +319,19 @@ def prepare_environment():
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
if not is_installed("lpips"):
- run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
+ run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
- run_pip(f"install -r {requirements_file}", "requirements for Web UI")
+ if not os.path.isfile(requirements_file):
+ requirements_file = os.path.join(script_path, requirements_file)
+ run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
run_extensions_installers(settings_file=args.ui_settings_file)
if update_check:
version_check(commit)
+
+ if update_all_extensions:
+ git_pull_recursive(os.path.join(data_path, dir_extensions))
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
@@ -327,7 +347,7 @@ def tests(test_dir):
sys.argv.append("--api")
if "--ckpt" not in sys.argv:
sys.argv.append("--ckpt")
- sys.argv.append("./test/test_files/empty.pt")
+ sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
@@ -336,7 +356,7 @@ def tests(test_dir):
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
os.environ['COMMANDLINE_ARGS'] = ""
- with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
+ with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
import test.server_poll
diff --git a/modules/api/api.py b/modules/api/api.py
index 5a9ac5f1..35e17afc 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -150,6 +150,7 @@ class Api:
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
@@ -163,47 +164,98 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
- def get_script(self, script_name, script_runner):
- if script_name is None:
+ def get_selectable_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
return None, None
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
-
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
+
+ def get_scripts_list(self):
+ t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
+ i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+
+ return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+
+ def get_script(self, script_name, script_runner):
+ if script_name is None or script_name == "":
+ return None, None
+
+ script_idx = script_name_to_index(script_name, script_runner.scripts)
+ return script_runner.scripts[script_idx]
+
+ def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
+ #find max idx from the scripts in runner and generate a none array to init script_args
+ last_arg_index = 1
+ for script in script_runner.scripts:
+ if last_arg_index < script.args_to:
+ last_arg_index = script.args_to
+ # None everywhere except position 0 to initialize script args
+ script_args = [None]*last_arg_index
+ # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
+ if selectable_scripts:
+ script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
+ script_args[0] = selectable_idx + 1
+ else:
+ # when [0] = 0 no selectable script to run
+ script_args[0] = 0
+
+ # Now check for always on scripts
+ if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+ for alwayson_script_name in request.alwayson_scripts.keys():
+ alwayson_script = self.get_script(alwayson_script_name, script_runner)
+ if alwayson_script == None:
+ raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
+ # Selectable script in always on script param check
+ if alwayson_script.alwayson == False:
+ raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ # always on script with no arg should always run so you don't really need to add them to the requests
+ if "args" in request.alwayson_scripts[alwayson_script_name]:
+ script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ return script_args
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
- script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+ script_runner = scripts.scripts_txt2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
- populate = txt2imgreq.copy(update={ # Override __init__ params
+ populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True
- }
- )
+ "do_not_save_samples": not txt2imgreq.save_images,
+ "do_not_save_grid": not txt2imgreq.save_images,
+ })
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin()
- if script is not None:
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- b64images = list(map(encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -212,41 +264,53 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
- script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
-
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
- populate = img2imgreq.copy(update={ # Override __init__ params
+ script_runner = scripts.scripts_img2img
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(True)
+ ui.create_ui()
+ selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True,
- "mask": mask
- }
- )
+ "do_not_save_samples": not img2imgreq.save_images,
+ "do_not_save_grid": not img2imgreq.save_images,
+ "mask": mask,
+ })
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
args.pop('script_name', None)
+ args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+ args.pop('alwayson_scripts', None)
+
+ script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+ send_images = args.pop('send_images', True)
+ args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin()
- if script is not None:
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args)
+ if selectable_scripts != None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end()
- b64images = list(map(encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
if not img2imgreq.include_init_images:
img2imgreq.init_images = None
diff --git a/modules/api/models.py b/modules/api/models.py
index cba43d3b..4a70f440 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
"outpath_samples",
"outpath_grids",
"sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
+ # "do_not_save_samples",
+ # "do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
@@ -100,13 +100,31 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+ [
+ {"key": "sampler_index", "type": str, "default": "Euler"},
+ {"key": "init_images", "type": list, "default": None},
+ {"key": "denoising_strength", "type": float, "default": 0.75},
+ {"key": "mask", "type": str, "default": None},
+ {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+ {"key": "script_name", "type": str, "default": None},
+ {"key": "script_args", "type": list, "default": []},
+ {"key": "send_images", "type": bool, "default": True},
+ {"key": "save_images", "type": bool, "default": False},
+ {"key": "alwayson_scripts", "type": dict, "default": {}},
+ ]
).generate_model()
class TextToImageResponse(BaseModel):
@@ -267,3 +285,7 @@ class EmbeddingsResponse(BaseModel):
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+ txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
+ img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 01fb7bd8..8d84bbc9 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -55,7 +55,7 @@ def setup_model(dirname):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
diff --git a/modules/extensions.py b/modules/extensions.py
index 3eef9eaf..ed4b58fe 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -66,7 +66,7 @@ class Extension:
def check_updates(self):
repo = git.Repo(self.path)
- for fetch in repo.remote().fetch("--dry-run"):
+ for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
@@ -79,8 +79,8 @@ class Extension:
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
- repo.git.fetch('--all')
- repo.git.reset('--hard', 'origin')
+ repo.git.fetch(all=True)
+ repo.git.reset('origin', hard=True)
def list_extensions():
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 221e3d44..e8bb49f5 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -23,13 +23,14 @@ registered_param_bindings = []
class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
+ self.paste_field_names = paste_field_names
def reset():
@@ -134,7 +135,7 @@ def connect_paste_params_buttons():
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
if binding.source_tabname is not None and fields is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
binding.paste_button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
@@ -292,6 +293,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
settings_map = {}
+
+
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
@@ -300,7 +303,11 @@ infotext_to_setting_name_mapping = [
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
- ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC order', 'uni_pc_order'),
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
]
diff --git a/modules/images.py b/modules/images.py
index 5b80c23e..2da988ee 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif image_to_save.mode == 'I;16':
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
if opts.enable_pnginfo and info is not None:
exif_bytes = piexif.dump({
@@ -573,6 +573,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
os.replace(temp_file_path, filename_without_extension + extension)
fullfn_without_extension, extension = os.path.splitext(params.filename)
+ if hasattr(os, 'statvfs'):
+ max_name_len = os.statvfs(path).f_namemax
+ fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
+ params.filename = fullfn_without_extension + extension
+ fullfn = params.filename
_atomically_save_image(image, fullfn_without_extension, extension)
image.already_saved_as = fullfn
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..18e6ff72 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -45,7 +45,6 @@ if has_mps:
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
diff --git a/modules/memmon.py b/modules/memmon.py
index a7060f58..4018edcc 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
self.data = defaultdict(int)
try:
- torch.cuda.mem_get_info()
+ self.cuda_mem_get_info()
torch.cuda.memory_stats(self.device)
except Exception as e: # AMD or whatever
print(f"Warning: caught exception '{e}', memory monitor disabled")
self.disabled = True
+ def cuda_mem_get_info(self):
+ index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+ return torch.cuda.mem_get_info(index)
+
def run(self):
if self.disabled:
return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
self.run_flag.clear()
continue
- self.data["min_free"] = torch.cuda.mem_get_info()[0]
+ self.data["min_free"] = self.cuda_mem_get_info()[0]
while self.run_flag.is_set():
- free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
+ free, total = self.cuda_mem_get_info()
self.data["min_free"] = min(self.data["min_free"], free)
time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
- free, total = torch.cuda.mem_get_info()
+ free, total = self.cuda_mem_get_info()
self.data["free"] = free
self.data["total"] = total
diff --git a/modules/modelloader.py b/modules/modelloader.py
index fc3f6249..e351d808 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -6,7 +6,7 @@ from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
-from modules.upscaler import Upscaler
+from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -169,4 +169,8 @@ def load_upscalers():
scaler = cls(commandline_options.get(cmd_name, None))
datas += scaler.scalers
- shared.sd_upscalers = datas
+ shared.sd_upscalers = sorted(
+ datas,
+ # Special case for UpscalerNone keeps it at the beginning of the list.
+ key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
+ )
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
new file mode 100644
index 00000000..e1265e3f
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -0,0 +1 @@
+from .sampler import UniPCSampler
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
new file mode 100644
index 00000000..a241c8a7
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -0,0 +1,100 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
+from modules import shared, devices
+
+
+class UniPCSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.before_sample = None
+ self.after_sample = None
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != devices.device:
+ attr = attr.to(devices.device)
+ setattr(self, name, attr)
+
+ def set_hooks(self, before_sample, after_sample, after_update):
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ # print(f'Data shape for UniPC sampling is {size}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ # SD 1.X is "noise", SD 2.X is "v"
+ model_type = "v" if self.model.parameterization == "v" else "noise"
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=model_type,
+ guidance_type="classifier-free",
+ #condition=conditioning,
+ #unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
+ x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+ return x.to(device), None
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
new file mode 100644
index 00000000..eb5f4e76
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -0,0 +1,857 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm.auto import trange
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0**2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ #condition=None,
+ #unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, None, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input, condition):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous, condition, unconditional_condition):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input, condition)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ if isinstance(condition, dict):
+ assert isinstance(unconditional_condition, dict)
+ c_in = dict()
+ for k in condition:
+ if isinstance(condition[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_condition[k][i],
+ condition[k][i]]) for i in range(len(condition[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_condition[k],
+ condition[k]])
+ elif isinstance(condition, list):
+ c_in = list()
+ assert isinstance(unconditional_condition, list)
+ for i in range(len(condition)):
+ c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
+ else:
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class UniPC:
+ def __init__(
+ self,
+ model_fn,
+ noise_schedule,
+ predict_x0=True,
+ thresholding=False,
+ max_val=1.,
+ variant='bh1',
+ condition=None,
+ unconditional_condition=None,
+ before_sample=None,
+ after_sample=None,
+ after_update=None
+ ):
+ """Construct a UniPC.
+
+ We support both data_prediction and noise_prediction.
+ """
+ self.model_fn_ = model_fn
+ self.noise_schedule = noise_schedule
+ self.variant = variant
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+ self.condition = condition
+ self.unconditional_condition = unconditional_condition
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
+
+ def dynamic_thresholding_fn(self, x0, t=None):
+ """
+ The dynamic thresholding method.
+ """
+ dims = x0.dim()
+ p = self.dynamic_thresholding_ratio
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model(self, x, t):
+ cond = self.condition
+ uncond = self.unconditional_condition
+ if self.before_sample is not None:
+ x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
+ res = self.model_fn_(x, t, cond, uncond)
+ if self.after_sample is not None:
+ x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
+
+ if isinstance(res, tuple):
+ # (None, pred_x0)
+ res = res[1]
+
+ return res
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3,] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3,] * (K - 1) + [1]
+ else:
+ orders = [3,] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2,] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2,] * (K - 1) + [1]
+ elif order == 1:
+ K = steps
+ orders = [1,] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
+ if len(t.shape) == 0:
+ t = t.view(-1)
+ if 'bh' in self.variant:
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
+ else:
+ assert self.variant == 'vary_coeff'
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
+
+ def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
+ #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
+ ns = self.noise_schedule
+ assert order <= len(model_prev_list)
+
+ # first compute rks
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+ lambda_t = ns.marginal_lambda(t)
+ model_prev_0 = model_prev_list[-1]
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
+ rk = (lambda_prev_i - lambda_prev_0) / h
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.)
+ rks = torch.tensor(rks, device=x.device)
+
+ K = len(rks)
+ # build C matrix
+ C = []
+
+ col = torch.ones_like(rks)
+ for k in range(1, K + 1):
+ C.append(col)
+ col = col * rks / (k + 1)
+ C = torch.stack(C, dim=1)
+
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
+ A_p = C_inv_p
+
+ if use_corrector:
+ #print('using corrector')
+ C_inv = torch.linalg.inv(C)
+ A_c = C_inv
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh)
+ h_phi_ks = []
+ factorial_k = 1
+ h_phi_k = h_phi_1
+ for k in range(1, K + 2):
+ h_phi_ks.append(h_phi_k)
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
+ factorial_k *= (k + 1)
+
+ model_t = None
+ if self.predict_x0:
+ x_t_ = (
+ sigma_t / sigma_prev_0 * x
+ - alpha_t * h_phi_1 * model_prev_0
+ )
+ # now predictor
+ x_t = x_t_
+ if len(D1s) > 0:
+ # compute the residuals for predictor
+ for k in range(K - 1):
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
+ # now corrector
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_
+ k = 0
+ for k in range(K - 1):
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+ else:
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ x_t_ = (
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
+ - (sigma_t * h_phi_1) * model_prev_0
+ )
+ # now predictor
+ x_t = x_t_
+ if len(D1s) > 0:
+ # compute the residuals for predictor
+ for k in range(K - 1):
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
+ # now corrector
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_
+ k = 0
+ for k in range(K - 1):
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
+ return x_t, model_t
+
+ def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
+ #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
+ ns = self.noise_schedule
+ assert order <= len(model_prev_list)
+ dims = x.dim()
+
+ # first compute rks
+ t_prev_0 = t_prev_list[-1]
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
+ lambda_t = ns.marginal_lambda(t)
+ model_prev_0 = model_prev_list[-1]
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h = lambda_t - lambda_prev_0
+
+ rks = []
+ D1s = []
+ for i in range(1, order):
+ t_prev_i = t_prev_list[-(i + 1)]
+ model_prev_i = model_prev_list[-(i + 1)]
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+ rks.append(rk)
+ D1s.append((model_prev_i - model_prev_0) / rk)
+
+ rks.append(1.)
+ rks = torch.tensor(rks, device=x.device)
+
+ R = []
+ b = []
+
+ hh = -h[0] if self.predict_x0 else h[0]
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.variant == 'bh1':
+ B_h = hh
+ elif self.variant == 'bh2':
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= (i + 1)
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=x.device)
+
+ # now predictor
+ use_predictor = len(D1s) > 0 and x_t is None
+ if len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1) # (B, K)
+ if x_t is None:
+ # for order 2, we use a simplified version
+ if order == 2:
+ rhos_p = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ else:
+ D1s = None
+
+ if use_corrector:
+ #print('using corrector')
+ # for order 1, we use a simplified version
+ if order == 1:
+ rhos_c = torch.tensor([0.5], device=b.device)
+ else:
+ rhos_c = torch.linalg.solve(R, b)
+
+ model_t = None
+ if self.predict_x0:
+ x_t_ = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
+ )
+
+ if x_t is None:
+ if use_predictor:
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
+
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ if D1s is not None:
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
+ )
+ if x_t is None:
+ if use_predictor:
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
+
+ if use_corrector:
+ model_t = self.model_fn(x_t, t)
+ if D1s is not None:
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = (model_t - model_prev_0)
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+ return x_t, model_t
+
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05, corrector=False,
+ ):
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'multistep':
+ assert steps >= order, "UniPC order must be < sampling steps"
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+ for step in trange(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ #print('this step order:', step_order)
+ if step == steps:
+ #print('do not run corrector at the last step')
+ use_corrector = False
+ else:
+ use_corrector = True
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ model_prev_list[-1] = model_x
+ else:
+ raise NotImplementedError()
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,)*(dims - 1)]
diff --git a/modules/processing.py b/modules/processing.py
index 694e2637..0a46174c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -597,6 +597,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
+ extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -620,6 +621,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ if p.scripts is not None:
+ p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
if len(prompts) == 0:
break
@@ -753,7 +757,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks:
+ if not p.disable_extra_networks and extra_network_data:
extra_networks.deactivate(p, extra_network_data)
devices.torch_gc()
@@ -944,7 +948,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
+ img2img_sampler_name = self.sampler_name
+
+ if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+ img2img_sampler_name = 'DDIM'
if self.hr_sampler == '---':
pass
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index edd0e2a7..07911876 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -29,7 +29,7 @@ class ImageSaveParams:
class CFGDenoiserParams:
- def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -44,6 +44,12 @@ class CFGDenoiserParams:
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
+
+ self.text_cond = text_cond
+ """ Encoder hidden states of text conditioning from prompt"""
+
+ self.text_uncond = text_uncond
+ """ Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams:
diff --git a/modules/scripts.py b/modules/scripts.py
index 24056a12..8de19884 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -33,6 +33,11 @@ class Script:
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
"""
+ paste_field_names = None
+ """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
+ various "Send to <X>" buttons when clicked
+ """
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -80,6 +85,20 @@ class Script:
pass
+ def before_process_batch(self, p, *args, **kwargs):
+ """
+ Called before extra networks are parsed from the prompt, so you can add
+ new extra network keywords to the prompt with this callback.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ """
+
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -256,6 +275,7 @@ class ScriptRunner:
self.alwayson_scripts = []
self.titles = []
self.infotext_fields = []
+ self.paste_field_names = []
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -304,6 +324,9 @@ class ScriptRunner:
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
+
inputs += controls
inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
@@ -388,6 +411,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def before_process_batch(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process_batch(p, *script_args, **kwargs)
+ except Exception:
+ print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 79476783..f4bb0266 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -37,11 +37,23 @@ def apply_optimizations():
optimization_method = None
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
+ elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
+ optimization_method = 'sdp-no-mem'
+ elif cmd_opts.opt_sdp_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
+ optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index c02d954c..2e307b5d 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
+# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
+# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, inner_dim = x.shape
+
+ if mask is not None:
+ mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
+ mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
+
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ head_dim = inner_dim // h
+ q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+ v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+
+ del q_in, k_in, v_in
+
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
+ hidden_states = hidden_states.to(dtype)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return scaled_dot_product_attention_forward(self, x, context, mask)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+def sdp_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ dtype = q.dtype
+ if shared.opts.upcast_attn:
+ q, k = q.float(), k.float()
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+ out = out.to(dtype)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
+
+def sdp_no_mem_attnblock_forward(self, x):
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+ return sdp_attnblock_forward(self, x)
+
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 93959f55..f0cb1240 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+def read_metadata_from_safetensors(filename):
+ import json
+
+ with open(filename, mode="rb") as file:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+
+ assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+
+ res = {}
+ for k, v in json_obj.get("__metadata__", {}).items():
+ res[k] = v
+ if isinstance(v, str) and v[0:1] == '{':
+ try:
+ res[k] = json.loads(v)
+ except Exception as e:
+ pass
+
+ return res
+
+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 28c2136f..ff361f22 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -32,7 +32,7 @@ def set_samplers():
global samplers, samplers_for_img2img
hidden = set(shared.opts.hide_samplers)
- hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
+ hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index d03131cd..083da18c 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -7,19 +7,27 @@ import torch
from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
+import modules.models.diffusion.uni_pc
samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+ sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
+ self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
+ self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
+ self.orig_p_sample_ddim = None
+ if self.is_plms:
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms
+ elif self.is_ddim:
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler:
return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
+
+ return res
+
+ def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler:
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x_dec = img_orig * self.mask + self.nmask * x_dec
+ x = img_orig * self.mask + self.nmask * x
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
@@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+ return x, ts, cond, unconditional_conditioning
+ def update_step(self, last_latent):
if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+ self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else:
- self.last_latent = res[1]
+ self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent)
@@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step
shared.total_tqdm.update()
- return res
+ def after_sample(self, x, ts, cond, uncond, res):
+ if not self.is_unipc:
+ self.update_step(res[1])
+
+ return x, ts, cond, uncond, res
+
+ def unipc_after_update(self, x, model_x):
+ self.update_step(x)
def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
+ if self.is_unipc:
+ keys = [
+ ('UniPC variant', 'uni_pc_variant'),
+ ('UniPC skip type', 'uni_pc_skip_type'),
+ ('UniPC order', 'uni_pc_order'),
+ ('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ]
+
+ for name, key in keys:
+ v = getattr(shared.opts, key)
+ if v != shared.opts.get_default(key):
+ p.extra_generation_params[name] = v
+
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+ if self.is_unipc:
+ self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
+ if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
+ num_steps = shared.opts.uni_pc_order
valid_step = 999 / (1000 // num_steps)
if valid_step == math.floor(valid_step):
return int(valid_step) + 1
-
+
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 528f513f..93f0e55a 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
if tensor.shape[1] == uncond.shape[1]:
if not is_edit_model:
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0027343a..e2f00468 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -35,8 +35,11 @@ def model():
global sd_vae_approx_model
if sd_vae_approx_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
+ sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
diff --git a/modules/shared.py b/modules/shared.py
index 805f9cc1..f28a12cc 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -69,6 +69,8 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
@@ -114,7 +116,10 @@ parser.add_argument("--no-download-sd-model", action='store_true', help="don't d
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
-cmd_opts = parser.parse_args()
+if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
+ cmd_opts = parser.parse_args()
+else:
+ cmd_opts, _ = parser.parse_known_args()
restricted_opts = {
"samples_filename_pattern",
@@ -305,6 +310,7 @@ def list_samplers():
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
+tab_names = []
options_templates = {}
@@ -327,9 +333,11 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+ "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
+ "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
@@ -440,6 +448,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
@@ -460,6 +469,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
@@ -485,6 +495,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
+ 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
+ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
+ 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
+ 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
}))
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
@@ -559,6 +573,15 @@ class Options:
return True
+ def get_default(self, key):
+ """returns the default value for the key"""
+
+ data_label = self.data_labels.get(key)
+ if data_label is None:
+ return None
+
+ return data_label.default
+
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
@@ -691,6 +714,7 @@ class TotalTQDM:
def clear(self):
if self._tqdm is not None:
+ self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
diff --git a/modules/timer.py b/modules/timer.py
index 57a4f17a..ba92be33 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -33,3 +33,6 @@ class Timer:
res += ")"
return res
+
+ def reset(self):
+ self.__init__()
diff --git a/modules/ui.py b/modules/ui.py
index 9051a5ec..48a374c6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -957,7 +957,7 @@ def create_ui():
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
@@ -1581,6 +1581,10 @@ def create_ui():
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
+ shared.tab_names = []
+ for _interface, label, _ifid in interfaces:
+ shared.tab_names.append(label)
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
@@ -1591,6 +1595,8 @@ def create_ui():
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
+ if label in shared.opts.hidden_tabs:
+ continue
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
@@ -1763,7 +1769,8 @@ def create_ui():
def reload_javascript():
- head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
+ script_js = os.path.join(script_path, "script.js")
+ head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
@@ -1772,6 +1779,9 @@ def reload_javascript():
for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
+ for script in modules.scripts.list_scripts("javascript", ".mjs"):
+ head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
+
head += f'<script type="text/javascript">{inline}</script>\n'
def template_response(*args, **kwargs):
diff --git a/modules/ui_common.py b/modules/ui_common.py
index fd047f31..a12433d2 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -198,9 +198,16 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ paste_field_names = []
+ if tabname == "txt2img":
+ paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
+ elif tabname == "img2img":
+ paste_field_names = modules.scripts.scripts_img2img.paste_field_names
+
for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
+ paste_field_names=paste_field_names
))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index bd4308ef..df75a925 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -304,7 +304,7 @@ def create_ui():
with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 71f1d81f..cdfd6f2a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -30,8 +30,8 @@ def add_pages_to_demo(app):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+ if ext not in (".png", ".jpg", ".webp"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -124,19 +124,56 @@ class ExtraNetworksPage:
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ metadata_button = ""
+ metadata = item.get("metadata")
+ if metadata:
+ metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
+ metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
+
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
+ "description": (item.get("description") or ""),
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
+ "metadata_button": metadata_button,
}
return self.card_page.format(**args)
+ def find_preview(self, path):
+ """
+ Find a preview PNG for a given path (without extension) and call link_preview on it.
+ """
+
+ preview_extensions = ["png", "jpg", "webp"]
+ if shared.opts.samples_format not in preview_extensions:
+ preview_extensions.append(shared.opts.samples_format)
+
+ potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
+
+ for file in potential_files:
+ if os.path.isfile(file):
+ return self.link_preview(file)
+
+ return None
+
+ def find_description(self, path):
+ """
+ Find and read a description file for a given path (without extension).
+ """
+ for file in [f"{path}.txt", f"{path}.description.txt"]:
+ try:
+ with open(file, "r", encoding="utf-8", errors="replace") as f:
+ return f.read()
+ except OSError:
+ pass
+ return None
+
def intialize():
extra_pages.clear()
@@ -183,7 +220,6 @@ def create_ui(container, button, tabname):
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
- button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
@@ -194,7 +230,6 @@ def create_ui(container, button, tabname):
state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
- button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh():
res = []
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 04097a79..a17aa9c9 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -1,7 +1,6 @@
import html
import json
import os
-import urllib.parse
from modules import shared, ui_extra_networks, sd_models
@@ -17,21 +16,14 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items():
path, ext = os.path.splitext(checkpoint.filename)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
-
yield {
"name": checkpoint.name_for_extra,
"filename": path,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "local_preview": path + ".png",
+ "local_preview": f"{path}.{shared.opts.samples_format}",
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 57851088..6187e000 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -14,21 +14,15 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
for name, path in shared.hypernetworks.items():
path, ext = os.path.splitext(path)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
yield {
"name": name,
"filename": path,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(path),
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": path + ".png",
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index bb64eb81..6944d559 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,7 +1,7 @@
import json
import os
-from modules import ui_extra_networks, sd_hijack
+from modules import ui_extra_networks, sd_hijack, shared
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
@@ -15,19 +15,14 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
path, ext = os.path.splitext(embedding.filename)
- preview_file = path + ".preview.png"
-
- preview = None
- if os.path.isfile(preview_file):
- preview = self.link_preview(preview_file)
-
yield {
"name": embedding.name,
"filename": embedding.filename,
- "preview": preview,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
"search_term": self.search_terms_from_path(embedding.filename),
"prompt": json.dumps(embedding.name),
- "local_preview": path + ".preview.png",
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
}
def allowed_directories_for_previews(self):
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 331d0fe8..0031c616 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -23,8 +23,8 @@ torchdiffeq==0.2.3
kornia==0.6.7
lark==1.1.2
inflection==0.5.1
-GitPython==3.1.27
+GitPython==3.1.30
torchsde==0.2.5
safetensors==0.2.7
httpcore<=0.15
-fastapi==0.90.1
+fastapi==0.94.0
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index b1c486d4..e9b11517 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -100,7 +100,7 @@ class Script(scripts.Script):
processed = process_images(p)
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
- grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size)
+ grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid)
processed.index_of_first_image = 1
processed.infotexts.insert(0, processed.infotexts[0])
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 53511b12..ce584981 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -128,6 +128,24 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
p.styles.extend(x.split(','))
+def apply_uni_pc_order(p, x, xs):
+ opts.data["uni_pc_order"] = min(x, p.steps - 1)
+
+
+def apply_face_restore(p, opt, x):
+ opt = opt.lower()
+ if opt == 'codeformer':
+ is_active = True
+ p.face_restoration_model = 'CodeFormer'
+ elif opt == 'gfpgan':
+ is_active = True
+ p.face_restoration_model = 'GFPGAN'
+ else:
+ is_active = opt in ('true', 'yes', 'y', '1')
+
+ p.restore_faces = is_active
+
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -205,6 +223,8 @@ axis_options = [
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
+ AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
+ AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
]
@@ -213,49 +233,47 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
- # Temporary list of all the images that are generated to be populated into the grid.
- # Will be filled with empty images for any individual step that fails to process properly
- image_cache = [None] * (len(xs) * len(ys) * len(zs))
+ list_size = (len(xs) * len(ys) * len(zs))
processed_result = None
- cell_mode = "P"
- cell_size = (1, 1)
- state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
+ state.job_count = list_size * p.n_iter
def process_cell(x, y, z, ix, iy, iz):
- nonlocal image_cache, processed_result, cell_mode, cell_size
+ nonlocal processed_result
def index(ix, iy, iz):
return ix + iy * len(xs) + iz * len(xs) * len(ys)
- state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
+ state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
processed: Processed = cell(x, y, z)
- try:
- # this dereference will throw an exception if the image was not processed
- # (this happens in cases such as if the user stops the process from the UI)
- processed_image = processed.images[0]
-
- if processed_result is None:
- # Use our first valid processed result as a template container to hold our full results
- processed_result = copy(processed)
- cell_mode = processed_image.mode
- cell_size = processed_image.size
- processed_result.images = [Image.new(cell_mode, cell_size)]
- processed_result.all_prompts = [processed.prompt]
- processed_result.all_seeds = [processed.seed]
- processed_result.infotexts = [processed.infotexts[0]]
-
- image_cache[index(ix, iy, iz)] = processed_image
- if include_lone_images:
- processed_result.images.append(processed_image)
- processed_result.all_prompts.append(processed.prompt)
- processed_result.all_seeds.append(processed.seed)
- processed_result.infotexts.append(processed.infotexts[0])
- except:
- image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
+ if processed_result is None:
+ # Use our first processed result object as a template container to hold our full results
+ processed_result = copy(processed)
+ processed_result.images = [None] * list_size
+ processed_result.all_prompts = [None] * list_size
+ processed_result.all_seeds = [None] * list_size
+ processed_result.infotexts = [None] * list_size
+ processed_result.index_of_first_image = 1
+
+ idx = index(ix, iy, iz)
+ if processed.images:
+ # Non-empty list indicates some degree of success.
+ processed_result.images[idx] = processed.images[0]
+ processed_result.all_prompts[idx] = processed.prompt
+ processed_result.all_seeds[idx] = processed.seed
+ processed_result.infotexts[idx] = processed.infotexts[0]
+ else:
+ cell_mode = "P"
+ cell_size = (processed_result.width, processed_result.height)
+ if processed_result.images[0] is not None:
+ cell_mode = processed_result.images[0].mode
+ #This corrects size in case of batches:
+ cell_size = processed_result.images[0].size
+ processed_result.images[idx] = Image.new(cell_mode, cell_size)
+
if first_axes_processed == 'x':
for ix, x in enumerate(xs):
@@ -289,36 +307,48 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
process_cell(x, y, z, ix, iy, iz)
if not processed_result:
+ # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
+ print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
+ return Processed(p, [])
+ elif not any(processed_result.images):
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
return Processed(p, [])
- sub_grids = [None] * len(zs)
- for i in range(len(zs)):
- start_index = i * len(xs) * len(ys)
+ z_count = len(zs)
+ sub_grids = [None] * z_count
+ for i in range(z_count):
+ start_index = (i * len(xs) * len(ys)) + i
end_index = start_index + len(xs) * len(ys)
- grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
+ grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
if draw_legend:
- grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
- sub_grids[i] = grid
- if include_sub_grids and len(zs) > 1:
- processed_result.images.insert(i+1, grid)
-
- sub_grid_size = sub_grids[0].size
- z_grid = images.image_grid(sub_grids, rows=1)
+ grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
+ processed_result.images.insert(i, grid)
+ processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
+ processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
+ processed_result.infotexts.insert(i, processed_result.infotexts[start_index])
+
+ sub_grid_size = processed_result.images[0].size
+ z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
if draw_legend:
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
- processed_result.images[0] = z_grid
+ processed_result.images.insert(0, z_grid)
+ #TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
+ #processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
+ #processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
+ processed_result.infotexts.insert(0, processed_result.infotexts[0])
- return processed_result, sub_grids
+ return processed_result
class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.vae = opts.sd_vae
+ self.uni_pc_order = opts.uni_pc_order
def __exit__(self, exc_type, exc_value, tb):
opts.data["sd_vae"] = self.vae
+ opts.data["uni_pc_order"] = self.uni_pc_order
modules.sd_models.reload_model_weights()
modules.sd_vae.reload_vae_weights()
@@ -418,7 +448,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]
- valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.type == int:
valslist_ext = []
@@ -484,6 +514,10 @@ class Script(scripts.Script):
z_opt = self.current_axis_options[z_type]
zs = process_axis(z_opt, z_values)
+ # this could be moved to common code, but unlikely to be ever triggered anywhere else
+ grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
+ assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
+
def fix_axis_seeds(axis_opt, axis_list):
if axis_opt.label in ['Seed', 'Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
@@ -533,7 +567,7 @@ class Script(scripts.Script):
# If one of the axes is very slow to change between (like SD model
# checkpoint), then make sure it is in the outer iteration of the nested
# `for` loop.
- first_axes_processed = 'x'
+ first_axes_processed = 'z'
second_axes_processed = 'y'
if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
first_axes_processed = 'x'
@@ -593,7 +627,7 @@ class Script(scripts.Script):
return res
with SharedSettingsStackHelper():
- processed, sub_grids = draw_xyz_grid(
+ processed = draw_xyz_grid(
p,
xs=xs,
ys=ys,
@@ -610,11 +644,30 @@ class Script(scripts.Script):
margin_size=margin_size
)
- if opts.grid_save and len(sub_grids) > 1:
- for sub_grid in sub_grids:
- images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ if not processed.images:
+ # It broke, no further handling needed.
+ return processed
+
+ z_count = len(zs)
+
+ if not include_lone_images:
+ # Don't need sub-images anymore, drop from list:
+ processed.images = processed.images[:z_count+1]
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ # Auto-save main and sub-grids:
+ grid_count = z_count + 1 if z_count > 1 else 1
+ for g in range(grid_count):
+ #TODO: See previous comment about intentional data misalignment.
+ adj_g = g-1 if g > 0 else g
+ images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
+
+ if not include_sub_grids:
+ # Done with sub-grids, drop all related information:
+ for sg in range(z_count):
+ del processed.images[1]
+ del processed.all_prompts[1]
+ del processed.all_seeds[1]
+ del processed.infotexts[1]
return processed
diff --git a/style.css b/style.css
index 05572f66..3eac2b17 100644
--- a/style.css
+++ b/style.css
@@ -362,6 +362,46 @@ input[type="range"]{
height: 100%;
}
+.popup-metadata{
+ color: black;
+ background: white;
+ display: inline-block;
+ padding: 1em;
+ white-space: pre-wrap;
+}
+
+.global-popup{
+ display: flex;
+ position: fixed;
+ z-index: 1001;
+ left: 0;
+ top: 0;
+ width: 100%;
+ height: 100%;
+ overflow: auto;
+ background-color: rgba(20, 20, 20, 0.95);
+}
+
+
+.global-popup-close:before {
+ content: "×";
+}
+
+.global-popup-close{
+ position: fixed;
+ right: 0.25em;
+ top: 0;
+ cursor: pointer;
+ color: white;
+ font-size: 32pt;
+}
+
+.global-popup-inner{
+ display: inline-block;
+ margin: auto;
+ padding: 2em;
+}
+
#lightboxModal{
display: none;
position: fixed;
@@ -436,9 +476,7 @@ input[type="range"]{
#modalImage {
display: block;
- margin-left: auto;
- margin-right: auto;
- margin-top: auto;
+ margin: auto;
width: auto;
}
@@ -839,6 +877,27 @@ footer {
margin-left: 0.5em;
}
+
+.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
+ content: "🛈";
+}
+.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
+ display: none;
+ position: absolute;
+ right: 0;
+ color: white;
+ text-shadow: 2px 2px 3px black;
+ padding: 0.25em;
+ font-size: 22pt;
+}
+.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
+ display: inline-block;
+}
+.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{
+ color: red;
+}
+
+
.extra-network-thumbs {
display: flex;
flex-flow: row wrap;
@@ -856,7 +915,7 @@ footer {
}
.extra-network-thumbs .card:hover .additional a {
- display: block;
+ display: inline-block;
}
.extra-network-thumbs .actions .additional a {
@@ -939,6 +998,17 @@ footer {
line-break: anywhere;
}
+.extra-network-cards .card .actions .description {
+ display: block;
+ max-height: 3em;
+ white-space: pre-wrap;
+ line-height: 1.1;
+}
+
+.extra-network-cards .card .actions .description:hover {
+ max-height: none;
+}
+
.extra-network-cards .card .actions:hover .additional{
display: block;
}
diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py
index 0170c511..8ed98747 100644
--- a/test/basic_features/extras_test.py
+++ b/test/basic_features/extras_test.py
@@ -1,7 +1,9 @@
+import os
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image
+from modules.paths import script_path
class TestExtrasWorking(unittest.TestCase):
def setUp(self):
@@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase):
"upscaler_1": "None",
"upscaler_2": "None",
"extras_upscaler_2_visibility": 0,
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+ "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
}
def test_simple_upscaling_performed(self):
@@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase):
def setUp(self):
self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
self.png_info = {
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+ "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
}
def test_png_info_performed(self):
@@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase):
def setUp(self):
self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
self.interrogate = {
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
+ "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))),
"model": "clip"
}
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
index 08c5c903..5240ec36 100644
--- a/test/basic_features/img2img_test.py
+++ b/test/basic_features/img2img_test.py
@@ -1,14 +1,16 @@
+import os
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image
+from modules.paths import script_path
class TestImg2ImgWorking(unittest.TestCase):
def setUp(self):
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
self.simple_img2img = {
- "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
+ "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))],
"resize_mode": 0,
"denoising_strength": 0.75,
"mask": None,
@@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+ self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_with_inverted_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+ self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
self.simple_img2img["inpainting_mask_invert"] = True
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py
index 5aa43a44..cb525fbb 100644
--- a/test/basic_features/txt2img_test.py
+++ b/test/basic_features/txt2img_test.py
@@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
self.simple_txt2img["sampler_index"] = "DDIM"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+ self.simple_txt2img["sampler_index"] = "UniPC"
+ self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_multiple_batches_performed(self):
self.simple_txt2img["n_iter"] = 2
diff --git a/test/server_poll.py b/test/server_poll.py
index 42d56a4c..c732630f 100644
--- a/test/server_poll.py
+++ b/test/server_poll.py
@@ -1,6 +1,8 @@
import unittest
import requests
import time
+import os
+from modules.paths import script_path
def run_tests(proc, test_dir):
@@ -15,8 +17,8 @@ def run_tests(proc, test_dir):
break
if proc.poll() is None:
if test_dir is None:
- test_dir = "test"
- suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
+ test_dir = os.path.join(script_path, "test")
+ suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
result = unittest.TextTestRunner(verbosity=2).run(suite)
return len(result.failures) + len(result.errors)
else:
diff --git a/webui.py b/webui.py
index 9e8b486a..aaec79fd 100644
--- a/webui.py
+++ b/webui.py
@@ -12,11 +12,22 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
-from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules import paths, timer, import_hook, errors
+
+startup_timer = timer.Timer()
import torch
+startup_timer.record("import torch")
+
+import gradio
+startup_timer.record("import gradio")
+
+import ldm.modules.encoders.modules
+startup_timer.record("import ldm")
+
+from modules import extra_networks, ui_extra_networks_checkpoints
+from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
+from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
@@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.lowvram
-import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
@@ -45,6 +55,8 @@ from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
+startup_timer.record("other imports")
+
if cmd_opts.server_name:
server_name = cmd_opts.server_name
@@ -88,6 +100,7 @@ def initialize():
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list extensions")
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
@@ -96,16 +109,28 @@ def initialize():
modelloader.cleanup_models()
modules.sd_models.setup_model()
+ startup_timer.record("list SD models")
+
codeformer.setup_model(cmd_opts.codeformer_models_path)
+ startup_timer.record("setup codeformer")
+
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+ startup_timer.record("setup gfpgan")
modelloader.list_builtin_upscalers()
+ startup_timer.record("list builtin upscalers")
+
modules.scripts.load_scripts()
+ startup_timer.record("load scripts")
+
modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
+ startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
+ startup_timer.record("refresh textual inversion templates")
try:
modules.sd_models.load_model()
@@ -114,6 +139,7 @@ def initialize():
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
+ startup_timer.record("load SD checkpoint")
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
@@ -121,8 +147,10 @@ def initialize():
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ startup_timer.record("opts onchange")
shared.reload_hypernetworks()
+ startup_timer.record("reload hypernets")
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
@@ -131,6 +159,7 @@ def initialize():
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+ startup_timer.record("extra networks")
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
@@ -144,6 +173,7 @@ def initialize():
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
+ startup_timer.record("TLS")
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
@@ -153,13 +183,16 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler)
-def setup_cors(app):
+def setup_middleware(app):
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
def create_api(app):
@@ -183,12 +216,12 @@ def api_only():
initialize()
app = FastAPI()
- setup_cors(app)
- app.add_middleware(GZipMiddleware, minimum_size=1000)
+ setup_middleware(app)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
+ print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
@@ -199,21 +232,24 @@ def webui():
while 1:
if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
+ startup_timer.record("cleanup temp dir")
modules.script_callbacks.before_ui_callback()
+ startup_timer.record("scripts before_ui_callback")
shared.demo = modules.ui.create_ui()
+ startup_timer.record("create ui")
if cmd_opts.gradio_queue:
shared.demo.queue(64)
gradio_auth_creds = []
if cmd_opts.gradio_auth:
- gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
+ gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
- gradio_auth_creds += [x.strip() for x in line.split(',')]
+ gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
@@ -229,15 +265,15 @@ def webui():
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
+ startup_timer.record("gradio launch")
+
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attacker wants, including installing an extension and
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
- setup_cors(app)
-
- app.add_middleware(GZipMiddleware, minimum_size=1000)
+ setup_middleware(app)
modules.progress.setup_progress_api(app)
@@ -247,28 +283,42 @@ def webui():
ui_extra_networks.add_pages_to_demo(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
+ startup_timer.record("scripts app_started_callback")
+
+ print(f"Startup time: {startup_timer.summary()}.")
wait_on_server(shared.demo)
print('Restarting UI...')
+ startup_timer.reset()
+
sd_samplers.set_samplers()
modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
+ startup_timer.record("list extensions")
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
+ startup_timer.record("load scripts")
+
modules.script_callbacks.model_loaded_callback(shared.sd_model)
+ startup_timer.record("model loaded callback")
+
modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
+ startup_timer.record("reload script modules")
modules.sd_models.list_models()
+ startup_timer.record("list SD models")
shared.reload_hypernetworks()
+ startup_timer.record("reload hypernetworks")
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
@@ -277,6 +327,7 @@ def webui():
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+ startup_timer.record("initialize extra networks")
if __name__ == "__main__":