aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/ISSUE_TEMPLATE/feature_request.yml2
-rw-r--r--README.md10
-rw-r--r--html/licenses.html29
-rw-r--r--javascript/hints.js10
-rw-r--r--javascript/hires_fix.js25
-rw-r--r--modules/api/api.py48
-rw-r--r--modules/api/models.py6
-rw-r--r--modules/generation_parameters_copypaste.py17
-rw-r--r--modules/processing.py30
-rw-r--r--modules/sd_hijack.py28
-rw-r--r--modules/sd_hijack_clip.py4
-rw-r--r--modules/sd_hijack_optimizations.py125
-rw-r--r--modules/sd_vae.py20
-rw-r--r--modules/shared.py5
-rw-r--r--modules/sub_quadratic_attention.py214
-rw-r--r--modules/textual_inversion/textual_inversion.py166
-rw-r--r--modules/ui.py23
-rw-r--r--requirements.txt2
-rw-r--r--screenshot.pngbin525075 -> 420577 bytes
-rw-r--r--scripts/sd_upscale.py2
-rw-r--r--style.css17
-rw-r--r--test/basic_features/img2img_test.py6
-rw-r--r--txt2img_Screenshot.pngbin337094 -> 0 bytes
23 files changed, 638 insertions, 151 deletions
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index 8ca6e21f..35a88740 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -1,7 +1,7 @@
name: Feature request
description: Suggest an idea for this project
title: "[Feature Request]: "
-labels: ["suggestion"]
+labels: ["enhancement"]
body:
- type: checkboxes
diff --git a/README.md b/README.md
index 88250a6b..d783fdf0 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,7 @@
# Stable Diffusion web UI
A browser interface based on Gradio library for Stable Diffusion.
-![](txt2img_Screenshot.png)
-
-Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
+![](screenshot.png)
## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
@@ -97,9 +95,8 @@ Alternatively, use online services (like Google Colab):
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
-4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
+4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
+5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
### Automatic Installation on Linux
1. Install the dependencies:
@@ -141,6 +138,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
diff --git a/html/licenses.html b/html/licenses.html
index 9eeaa072..570630eb 100644
--- a/html/licenses.html
+++ b/html/licenses.html
@@ -184,7 +184,7 @@ SOFTWARE.
</pre>
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
-<small>Code added by contirubtors, most likely copied from this repository.</small>
+<small>Code added by contributors, most likely copied from this repository.</small>
<pre>
Apache License
@@ -390,3 +390,30 @@ SOFTWARE.
limitations under the License.
</pre>
+<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
+<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+
diff --git a/javascript/hints.js b/javascript/hints.js
index dda66e09..856e1389 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -4,7 +4,7 @@ titles = {
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
"Sampling method": "Which algorithm to use to produce the image",
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
- "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
+ "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
@@ -74,7 +74,7 @@ titles = {
"Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Apply style": "Insert selected styles into prompt fields",
- "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
+ "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
@@ -92,12 +92,12 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
- "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
+ "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
- "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
- "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.",
+ "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.",
+ "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.",
"Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js
new file mode 100644
index 00000000..07fba549
--- /dev/null
+++ b/javascript/hires_fix.js
@@ -0,0 +1,25 @@
+
+function setInactive(elem, inactive){
+ console.log(elem)
+ if(inactive){
+ elem.classList.add('inactive')
+ } else{
+ elem.classList.remove('inactive')
+ }
+}
+
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
+ console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y)
+
+ hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
+ hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
+ hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
+
+ gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
+
+ setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
+ setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
+ setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
+
+ return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
+}
diff --git a/modules/api/api.py b/modules/api/api.py
index 2103709b..5b6125f8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
@@ -28,8 +28,13 @@ def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+def script_name_to_index(name, scripts):
+ try:
+ return [script.title().lower() for script in scripts].index(name.lower())
+ except:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@@ -143,7 +148,21 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+ def get_script(self, script_name, script_runner):
+ if script_name is None:
+ return None, None
+
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+
+ script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
+ script = script_runner.selectable_scripts[script_idx]
+ return script, script_idx
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
@@ -153,14 +172,22 @@ class Api:
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
+ args = vars(populate)
+ args.pop('script_name', None)
+
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
-
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -170,6 +197,8 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
+ script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
+
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
@@ -186,13 +215,20 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ args.pop('script_name', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/api/models.py b/modules/api/models.py
index d8198a27..ce43c858 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -100,13 +100,13 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
class TextToImageResponse(BaseModel):
@@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel):
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
- upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 12a9de3d..f7f68b67 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res):
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
+ if shared.opts.use_old_hires_fix_width_height:
+ hires_width = int(res.get("Hires resize-1", None))
+ hires_height = int(res.get("Hires resize-2", None))
+
+ if hires_width is not None and hires_height is not None:
+ res['Size-1'] = hires_width
+ res['Size-2'] = hires_height
+ return
+
if firstpass_width is None or firstpass_height is None:
return
@@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
- # old algorithm for auto-calculating first pass size
- desired_pixel_count = 512 * 512
- actual_pixel_count = width * height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- firstpass_width = math.ceil(scale * width / 64) * 64
- firstpass_height = math.ceil(scale * height / 64) * 64
+ from modules import processing
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
diff --git a/modules/processing.py b/modules/processing.py
index 82157bc9..f04a0e1e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -98,7 +98,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -149,7 +149,7 @@ class StableDiffusionProcessing():
self.seed_resize_from_w = 0
self.scripts = None
- self.script_args = None
+ self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
@@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
return res
+def old_hires_fix_first_pass_dimensions(width, height):
+ """old algorithm for auto-calculating first pass size"""
+
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ width = math.ceil(scale * width / 64) * 64
+ height = math.ceil(scale * height / 64) * 64
+
+ return width, height
+
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
@@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_upscale_to_y = hr_resize_y
if firstphase_width != 0 or firstphase_height != 0:
- print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
- self.hr_scale = self.width / firstphase_width
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
self.width = firstphase_width
self.height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
+ self.applied_old_hires_behavior_to = None
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+ self.hr_resize_x = self.width
+ self.hr_resize_y = self.height
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
+
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+ self.applied_old_hires_behavior_to = (self.width, self.height)
+
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
self.extra_generation_params["Hires upscale"] = self.hr_scale
self.hr_upscale_to_x = int(self.width * self.hr_scale)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 71cc145a..6b0d95af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
@@ -43,20 +41,19 @@ def apply_optimizations():
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
+ elif cmd_opts.opt_sub_quad_attention:
+ print("Applying sub-quadratic cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+ optimization_method = 'sub-quadratic'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
- if not invokeAI_mps_available and shared.device.type == 'mps':
- print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'V1'
- else:
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'InvokeAI'
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
@@ -86,10 +83,12 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
- def hijack(self, m):
+ def __init__(self):
+ self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -120,7 +119,6 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
-
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5520c9b2..852afc66 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= original_mean / new_mean
+ z = z * (original_mean / new_mean)
return z
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 02c87f40..cdc63ed7 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,7 @@
import math
import sys
import traceback
-import importlib
+import psutil
import torch
from torch import einsum
@@ -12,6 +12,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
+from .sub_quadratic_attention import efficient_dot_product_attention
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
print(traceback.format_exc(), file=sys.stderr)
+def get_available_vram():
+ if shared.device.type == 'cuda':
+ stats = torch.cuda.memory_stats(shared.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ return mem_free_total
+ else:
+ return psutil.virtual_memory().available
+
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
+ mem_free_total = get_available_vram()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
-def check_for_psutil():
- try:
- spec = importlib.util.find_spec('psutil')
- return spec is not None
- except ModuleNotFoundError:
- return False
-
-invokeAI_mps_available = check_for_psutil()
-
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-if invokeAI_mps_available:
- import psutil
- mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
@@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+ assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+ x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+ out_proj, dropout = self.to_out
+ x = out_proj(x)
+ x = dropout(x)
+
+ return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+ bytes_per_token = torch.finfo(q.dtype).bits//8
+ batch_x_heads, q_tokens, _ = q.shape
+ _, k_tokens, _ = k.shape
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+ if chunk_threshold is None:
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+ elif chunk_threshold == 0:
+ chunk_threshold_bytes = None
+ else:
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+ kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+ elif kv_chunk_size_min == 0:
+ kv_chunk_size_min = None
+
+ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+ # the big matmul fits into our memory limit; do everything in 1 chunk,
+ # i.e. send it down the unchunked fast-path
+ query_chunk_size = q_tokens
+ kv_chunk_size = k_tokens
+
+ return efficient_dot_product_attention(
+ q,
+ k,
+ v,
+ query_chunk_size=q_chunk_size,
+ kv_chunk_size=kv_chunk_size,
+ kv_chunk_size_min = kv_chunk_size_min,
+ use_checkpoint=use_checkpoint,
+ )
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):
h_ = torch.zeros_like(k, device=q.device)
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
+ mem_free_total = get_available_vram()
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
@@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index ac71d62d..0a49daa1 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,9 @@
import torch
+import safetensors.torch
import os
import collections
from collections import namedtuple
-from modules import shared, devices, script_callbacks
+from modules import shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob
from copy import deepcopy
@@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
@@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print(f"Using VAE found similar to selected model: {vae_file}")
+ # if still not found, try look for ".vae.safetensors" beside model
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.safetensors"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
@@ -163,8 +172,9 @@ def load_vae(model, vae_file=None):
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+
+ vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
@@ -195,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
+
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
+
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
diff --git a/modules/shared.py b/modules/shared.py
index 865c3c07..a1e10201 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
@@ -394,6 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
new file mode 100644
index 00000000..55052815
--- /dev/null
+++ b/modules/sub_quadratic_attention.py
@@ -0,0 +1,214 @@
+# original source:
+# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
+# credit:
+# Amin Rezaei (original author)
+# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
+# implementation of:
+# Self-attention Does Not Need O(n2) Memory":
+# https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+from typing import Optional, NamedTuple, List
+
+
+def narrow_trunc(
+ input: Tensor,
+ dim: int,
+ start: int,
+ length: int
+) -> Tensor:
+ return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
+
+class AttnChunk(NamedTuple):
+ exp_values: Tensor
+ exp_weights_sum: Tensor
+ max_score: Tensor
+
+
+class SummarizeChunk:
+ @staticmethod
+ def __call__(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ ) -> AttnChunk: ...
+
+
+class ComputeQueryChunkAttn:
+ @staticmethod
+ def __call__(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ ) -> Tensor: ...
+
+
+def _summarize_chunk(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ scale: float,
+) -> AttnChunk:
+ attn_weights = torch.baddbmm(
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+ query,
+ key.transpose(1,2),
+ alpha=scale,
+ beta=0,
+ )
+ max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+ max_score = max_score.detach()
+ exp_weights = torch.exp(attn_weights - max_score)
+ exp_values = torch.bmm(exp_weights, value)
+ max_score = max_score.squeeze(-1)
+ return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
+
+def _query_chunk_attention(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ summarize_chunk: SummarizeChunk,
+ kv_chunk_size: int,
+) -> Tensor:
+ batch_x_heads, k_tokens, k_channels_per_head = key.shape
+ _, _, v_channels_per_head = value.shape
+
+ def chunk_scanner(chunk_idx: int) -> AttnChunk:
+ key_chunk = narrow_trunc(
+ key,
+ 1,
+ chunk_idx,
+ kv_chunk_size
+ )
+ value_chunk = narrow_trunc(
+ value,
+ 1,
+ chunk_idx,
+ kv_chunk_size
+ )
+ return summarize_chunk(query, key_chunk, value_chunk)
+
+ chunks: List[AttnChunk] = [
+ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+ ]
+ acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+ chunk_values, chunk_weights, chunk_max = acc_chunk
+
+ global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+ max_diffs = torch.exp(chunk_max - global_max)
+ chunk_values *= torch.unsqueeze(max_diffs, -1)
+ chunk_weights *= max_diffs
+
+ all_values = chunk_values.sum(dim=0)
+ all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+ return all_values / all_weights
+
+
+# TODO: refactor CrossAttention#get_attention_scores to share code with this
+def _get_attention_scores_no_kv_chunking(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ scale: float,
+) -> Tensor:
+ attn_scores = torch.baddbmm(
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+ query,
+ key.transpose(1,2),
+ alpha=scale,
+ beta=0,
+ )
+ attn_probs = attn_scores.softmax(dim=-1)
+ del attn_scores
+ hidden_states_slice = torch.bmm(attn_probs, value)
+ return hidden_states_slice
+
+
+class ScannedChunk(NamedTuple):
+ chunk_idx: int
+ attn_chunk: AttnChunk
+
+
+def efficient_dot_product_attention(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ query_chunk_size=1024,
+ kv_chunk_size: Optional[int] = None,
+ kv_chunk_size_min: Optional[int] = None,
+ use_checkpoint=True,
+):
+ """Computes efficient dot-product attention given query, key, and value.
+ This is efficient version of attention presented in
+ https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
+ Args:
+ query: queries for calculating attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ key: keys for calculating attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ value: values to be used in attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ query_chunk_size: int: query chunks size
+ kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
+ kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+ Returns:
+ Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+ """
+ batch_x_heads, q_tokens, q_channels_per_head = query.shape
+ _, k_tokens, _ = key.shape
+ scale = q_channels_per_head ** -0.5
+
+ kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+ if kv_chunk_size_min is not None:
+ kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
+
+ def get_query_chunk(chunk_idx: int) -> Tensor:
+ return narrow_trunc(
+ query,
+ 1,
+ chunk_idx,
+ min(query_chunk_size, q_tokens)
+ )
+
+ summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+ summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+ compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
+ _get_attention_scores_no_kv_chunking,
+ scale=scale
+ ) if k_tokens <= kv_chunk_size else (
+ # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+ partial(
+ _query_chunk_attention,
+ kv_chunk_size=kv_chunk_size,
+ summarize_chunk=summarize_chunk,
+ )
+ )
+
+ if q_tokens <= query_chunk_size:
+ # fast-path for when there's just 1 query chunk
+ return compute_query_chunk_attn(
+ query=query,
+ key=key,
+ value=value,
+ )
+
+ # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
+ # and pass slices to be mutated, instead of torch.cat()ing the returned slices
+ res = torch.cat([
+ compute_query_chunk_attn(
+ query=get_query_chunk(i * query_chunk_size),
+ key=key,
+ value=value,
+ ) for i in range(math.ceil(q_tokens / query_chunk_size))
+ ], dim=1)
+ return res
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 45882ed6..217fe9eb 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -66,17 +66,41 @@ class Embedding:
return self.cached_checksum
+class DirWithTextualInversionEmbeddings:
+ def __init__(self, path):
+ self.path = path
+ self.mtime = None
+
+ def has_changed(self):
+ if not os.path.isdir(self.path):
+ return False
+
+ mt = os.path.getmtime(self.path)
+ if self.mtime is None or mt > self.mtime:
+ return True
+
+ def update(self):
+ if not os.path.isdir(self.path):
+ return
+
+ self.mtime = os.path.getmtime(self.path)
+
+
class EmbeddingDatabase:
- def __init__(self, embeddings_dir):
+ def __init__(self):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
- self.dir_mtime = None
- self.embeddings_dir = embeddings_dir
self.expected_shape = -1
+ self.embedding_dirs = {}
- def register_embedding(self, embedding, model):
+ def add_embedding_dir(self, path):
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+
+ def clear_embedding_dirs(self):
+ self.embedding_dirs.clear()
+ def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenize([embedding.name])[0]
@@ -93,65 +117,62 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
- def load_textual_inversion_embeddings(self, force_reload = False):
- mt = os.path.getmtime(self.embeddings_dir)
- if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
- return
+ def load_from_file(self, path, filename):
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
- self.skipped_embeddings.clear()
- self.expected_shape = self.get_expected_shape()
-
- def process_file(path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name', name)
- else:
- data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- else:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
return
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
-
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
- self.register_embedding(embedding, shared.sd_model)
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
else:
- self.skipped_embeddings[name] = embedding
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ elif ext in ['.BIN', '.PT']:
+ data = torch.load(path, map_location="cpu")
+ else:
+ return
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
- for root, dirs, fns in os.walk(self.embeddings_dir):
+ def load_from_dir(self, embdir):
+ if not os.path.isdir(embdir.path):
+ return
+
+ for root, dirs, fns in os.walk(embdir.path):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -159,12 +180,32 @@ class EmbeddingDatabase:
if os.stat(fullfn).st_size == 0:
continue
- process_file(fullfn, fn)
+ self.load_from_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
+ def load_textual_inversion_embeddings(self, force_reload=False):
+ if not force_reload:
+ need_reload = False
+ for path, embdir in self.embedding_dirs.items():
+ if embdir.has_changed():
+ need_reload = True
+ break
+
+ if not need_reload:
+ return
+
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
+
+ for path, embdir in self.embedding_dirs.items():
+ self.load_from_dir(embdir)
+ embdir.update()
+
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
@@ -247,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
- assert steps > 0 , "Max steps must be positive"
+ assert steps > 0, "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer"
- assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert save_model_every >= 0, "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer"
- assert create_image_every >= 0 , "Create image must be positive or 0"
+ assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
diff --git a/modules/ui.py b/modules/ui.py
index 6c765262..719c26b3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
with devices.autocast():
p.init([""], [0], [0])
- return f"resize to: <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+ return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
@@ -745,15 +745,20 @@ def create_ui():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- hr_resolution_preview_args = dict(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False
- )
-
for input in hr_resolution_preview_inputs:
- input.change(**hr_resolution_preview_args)
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
diff --git a/requirements.txt b/requirements.txt
index 4f09385f..e1dbf8e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,4 @@ inflection
GitPython
torchsde
safetensors
-psutil; sys_platform == 'darwin'
+psutil
diff --git a/screenshot.png b/screenshot.png
index 86c3209f..47a1be4e 100644
--- a/screenshot.png
+++ b/screenshot.png
Binary files differ
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 9b8ffd85..332d76d9 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -25,6 +25,8 @@ class Script(scripts.Script):
return [info, overlap, upscaler_index, scale_factor]
def run(self, p, _, overlap, upscaler_index, scale_factor):
+ if isinstance(upscaler_index, str):
+ upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
processing.fix_seed(p)
upscaler = shared.sd_upscalers[upscaler_index]
diff --git a/style.css b/style.css
index 76721756..ec5e4182 100644
--- a/style.css
+++ b/style.css
@@ -512,7 +512,7 @@ input[type="range"]{
border: none;
background: none;
flex: unset;
- gap: 0.5em;
+ gap: 1em;
}
#quicksettings > div > div{
@@ -521,6 +521,17 @@ input[type="range"]{
padding: 0;
}
+#quicksettings > div > div > div > div > label > span {
+ position: relative;
+ margin-right: 9em;
+ margin-bottom: -1em;
+}
+
+#quicksettings > div > div > label > span {
+ position: relative;
+ margin-bottom: -1em;
+}
+
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
@@ -659,6 +670,10 @@ footer {
min-width: auto;
}
+.inactive{
+ opacity: 0.5;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
index 0a9c1e8a..bd520b13 100644
--- a/test/basic_features/img2img_test.py
+++ b/test/basic_features/img2img_test.py
@@ -50,6 +50,12 @@ class TestImg2ImgWorking(unittest.TestCase):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+ def test_img2img_sd_upscale_performed(self):
+ self.simple_img2img["script_name"] = "sd upscale"
+ self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]
+
+ self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+
if __name__ == "__main__":
unittest.main()
diff --git a/txt2img_Screenshot.png b/txt2img_Screenshot.png
deleted file mode 100644
index 6e2759a4..00000000
--- a/txt2img_Screenshot.png
+++ /dev/null
Binary files differ