aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.eslintrc.js9
-rw-r--r--javascript/aspectRatioOverlay.js2
-rw-r--r--javascript/contextMenus.js4
-rw-r--r--javascript/dragdrop.js53
-rw-r--r--javascript/generationParams.js2
-rw-r--r--javascript/hints.js81
-rw-r--r--javascript/imageMaskFix.js2
-rw-r--r--javascript/imageParams.js18
-rw-r--r--javascript/imageviewer.js2
-rw-r--r--javascript/notification.js2
-rw-r--r--javascript/token-counters.js83
-rw-r--r--javascript/ui.js73
-rw-r--r--modules/api/api.py2
-rw-r--r--modules/devices.py18
-rw-r--r--modules/errors.py8
-rw-r--r--modules/generation_parameters_copypaste.py18
-rw-r--r--modules/images.py31
-rw-r--r--modules/img2img.py3
-rw-r--r--modules/processing.py4
-rw-r--r--modules/script_callbacks.py20
-rw-r--r--modules/sd_hijack.py20
-rw-r--r--modules/sd_models.py24
-rw-r--r--modules/sd_samplers_kdiffusion.py50
-rw-r--r--modules/sd_unet.py92
-rw-r--r--modules/shared.py8
-rw-r--r--modules/shared_items.py11
-rw-r--r--modules/ui.py6
-rw-r--r--modules/ui_common.py9
-rw-r--r--requirements.txt2
-rw-r--r--requirements_versions.txt2
-rw-r--r--script.js63
-rw-r--r--scripts/xyz_grid.py6
-rw-r--r--webui.py23
-rwxr-xr-xwebui.sh9
34 files changed, 564 insertions, 196 deletions
diff --git a/.eslintrc.js b/.eslintrc.js
index 944cc869..f33aca09 100644
--- a/.eslintrc.js
+++ b/.eslintrc.js
@@ -50,13 +50,14 @@ module.exports = {
globals: {
//script.js
gradioApp: "readonly",
+ executeCallbacks: "readonly",
+ onAfterUiUpdate: "readonly",
+ onOptionsChanged: "readonly",
onUiLoaded: "readonly",
onUiUpdate: "readonly",
- onOptionsChanged: "readonly",
uiCurrentTab: "writable",
- uiElementIsVisible: "readonly",
uiElementInSight: "readonly",
- executeCallbacks: "readonly",
+ uiElementIsVisible: "readonly",
//ui.js
opts: "writable",
all_gallery_buttons: "readonly",
@@ -84,5 +85,7 @@ module.exports = {
// imageviewer.js
modalPrevImage: "readonly",
modalNextImage: "readonly",
+ // token-counters.js
+ setupTokenCounters: "readonly",
}
};
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
index 1c08a1a9..2cf2d571 100644
--- a/javascript/aspectRatioOverlay.js
+++ b/javascript/aspectRatioOverlay.js
@@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) {
}
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if (arPreviewRect) {
arPreviewRect.style.display = 'none';
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index f14af1d4..d60a10c4 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -167,6 +167,4 @@ var addContextMenuEventListener = initResponse[2];
})();
//End example Context Menu Items
-onUiUpdate(function() {
- addContextMenuEventListener();
-});
+onAfterUiUpdate(addContextMenuEventListener);
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index 77a24a07..5803daea 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) {
}
}
+function eventHasFiles(e) {
+ if (!e.dataTransfer || !e.dataTransfer.files) return false;
+ if (e.dataTransfer.files.length > 0) return true;
+ if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true;
+
+ return false;
+}
+
+function dragDropTargetIsPrompt(target) {
+ if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true;
+ if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true;
+ return false;
+}
+
window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0];
- const imgWrap = target.closest('[data-testid="image"]');
- if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
- return;
- }
+ if (!eventHasFiles(e)) return;
+
+ var targetImage = target.closest('[data-testid="image"]');
+ if (!dragDropTargetIsPrompt(target) && !targetImage) return;
+
e.stopPropagation();
e.preventDefault();
e.dataTransfer.dropEffect = 'copy';
@@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => {
const target = e.composedPath()[0];
- if (target.placeholder.indexOf("Prompt") == -1) {
- return;
+ if (!eventHasFiles(e)) return;
+
+ if (dragDropTargetIsPrompt(target)) {
+ e.stopPropagation();
+ e.preventDefault();
+
+ let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
+
+ const imgParent = gradioApp().getElementById(prompt_target);
+ const files = e.dataTransfer.files;
+ const fileInput = imgParent.querySelector('input[type="file"]');
+ if (fileInput) {
+ fileInput.files = files;
+ fileInput.dispatchEvent(new Event('change'));
+ }
}
- const imgWrap = target.closest('[data-testid="image"]');
- if (!imgWrap) {
+
+ var targetImage = target.closest('[data-testid="image"]');
+ if (targetImage) {
+ e.stopPropagation();
+ e.preventDefault();
+ const files = e.dataTransfer.files;
+ dropReplaceImage(targetImage, files);
return;
}
- e.stopPropagation();
- e.preventDefault();
- const files = e.dataTransfer.files;
- dropReplaceImage(imgWrap, files);
});
window.addEventListener('paste', e => {
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
index a877f8a5..7c0fd221 100644
--- a/javascript/generationParams.js
+++ b/javascript/generationParams.js
@@ -1,7 +1,7 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img");
}
diff --git a/javascript/hints.js b/javascript/hints.js
index 46f342cb..05ae5f22 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -116,17 +116,25 @@ var titles = {
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
};
-function updateTooltipForSpan(span) {
- if (span.title) return; // already has a title
+function updateTooltip(element) {
+ if (element.title) return; // already has a title
- let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+ let text = element.textContent;
+ let tooltip = localization[titles[text]] || titles[text];
if (!tooltip) {
- tooltip = localization[titles[span.value]] || titles[span.value];
+ let value = element.value;
+ if (value) tooltip = localization[titles[value]] || titles[value];
}
if (!tooltip) {
- for (const c of span.classList) {
+ // Gradio dropdown options have `data-value`.
+ let dataValue = element.dataset.value;
+ if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
+ }
+
+ if (!tooltip) {
+ for (const c of element.classList) {
if (c in titles) {
tooltip = localization[titles[c]] || titles[c];
break;
@@ -135,34 +143,53 @@ function updateTooltipForSpan(span) {
}
if (tooltip) {
- span.title = tooltip;
+ element.title = tooltip;
}
}
-function updateTooltipForSelect(select) {
- if (select.onchange != null) return;
+// Nodes to check for adding tooltips.
+const tooltipCheckNodes = new Set();
+// Timer for debouncing tooltip check.
+let tooltipCheckTimer = null;
- select.onchange = function() {
- select.title = localization[titles[select.value]] || titles[select.value] || "";
- };
+function processTooltipCheckNodes() {
+ for (const node of tooltipCheckNodes) {
+ updateTooltip(node);
+ }
+ tooltipCheckNodes.clear();
}
-var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
-
-onUiUpdate(function(m) {
- m.forEach(function(record) {
- record.addedNodes.forEach(function(node) {
- if (observedTooltipElements[node.tagName]) {
- updateTooltipForSpan(node);
- }
- if (node.tagName == "SELECT") {
- updateTooltipForSelect(node);
+onUiUpdate(function(mutationRecords) {
+ for (const record of mutationRecords) {
+ if (record.type === "childList" && record.target.classList.contains("options")) {
+ // This smells like a Gradio dropdown menu having changed,
+ // so let's enqueue an update for the input element that shows the current value.
+ let wrap = record.target.parentNode;
+ let input = wrap?.querySelector("input");
+ if (input) {
+ input.title = ""; // So we'll even have a chance to update it.
+ tooltipCheckNodes.add(input);
}
-
- if (node.querySelectorAll) {
- node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
- node.querySelectorAll('select').forEach(updateTooltipForSelect);
+ }
+ for (const node of record.addedNodes) {
+ if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
+ if (!node.title) {
+ if (
+ node.tagName === "SPAN" ||
+ node.tagName === "BUTTON" ||
+ node.tagName === "P" ||
+ node.tagName === "INPUT" ||
+ (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
+ ) {
+ tooltipCheckNodes.add(node);
+ }
+ }
+ node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
}
- });
- });
+ }
+ }
+ if (tooltipCheckNodes.size) {
+ clearTimeout(tooltipCheckTimer);
+ tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
+ }
});
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js
index 3c9b8a6f..900c56f3 100644
--- a/javascript/imageMaskFix.js
+++ b/javascript/imageMaskFix.js
@@ -39,5 +39,5 @@ function imageMaskResize() {
});
}
-onUiUpdate(imageMaskResize);
+onAfterUiUpdate(imageMaskResize);
window.addEventListener('resize', imageMaskResize);
diff --git a/javascript/imageParams.js b/javascript/imageParams.js
deleted file mode 100644
index 057e2d39..00000000
--- a/javascript/imageParams.js
+++ /dev/null
@@ -1,18 +0,0 @@
-window.onload = (function() {
- window.addEventListener('drop', e => {
- const target = e.composedPath()[0];
- if (target.placeholder.indexOf("Prompt") == -1) return;
-
- let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
-
- e.stopPropagation();
- e.preventDefault();
- const imgParent = gradioApp().getElementById(prompt_target);
- const files = e.dataTransfer.files;
- const fileInput = imgParent.querySelector('input[type="file"]');
- if (fileInput) {
- fileInput.files = files;
- fileInput.dispatchEvent(new Event('change'));
- }
- });
-});
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 78e24eb9..677e95c1 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -170,7 +170,7 @@ function modalTileImageToggle(event) {
event.stopPropagation();
}
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox);
diff --git a/javascript/notification.js b/javascript/notification.js
index a68a76f2..76c5715d 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -4,7 +4,7 @@ let lastHeadImg = null;
let notificationButton = null;
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
if (notificationButton == null) {
notificationButton = gradioApp().getElementById('request_notifications');
diff --git a/javascript/token-counters.js b/javascript/token-counters.js
new file mode 100644
index 00000000..9d81a723
--- /dev/null
+++ b/javascript/token-counters.js
@@ -0,0 +1,83 @@
+let promptTokenCountDebounceTime = 800;
+let promptTokenCountTimeouts = {};
+var promptTokenCountUpdateFunctions = {};
+
+function update_txt2img_tokens(...args) {
+ // Called from Gradio
+ update_token_counter("txt2img_token_button");
+ if (args.length == 2) {
+ return args[0];
+ }
+ return args;
+}
+
+function update_img2img_tokens(...args) {
+ // Called from Gradio
+ update_token_counter("img2img_token_button");
+ if (args.length == 2) {
+ return args[0];
+ }
+ return args;
+}
+
+function update_token_counter(button_id) {
+ if (opts.disable_token_counters) {
+ return;
+ }
+ if (promptTokenCountTimeouts[button_id]) {
+ clearTimeout(promptTokenCountTimeouts[button_id]);
+ }
+ promptTokenCountTimeouts[button_id] = setTimeout(
+ () => gradioApp().getElementById(button_id)?.click(),
+ promptTokenCountDebounceTime,
+ );
+}
+
+
+function recalculatePromptTokens(name) {
+ promptTokenCountUpdateFunctions[name]?.();
+}
+
+function recalculate_prompts_txt2img() {
+ // Called from Gradio
+ recalculatePromptTokens('txt2img_prompt');
+ recalculatePromptTokens('txt2img_neg_prompt');
+ return Array.from(arguments);
+}
+
+function recalculate_prompts_img2img() {
+ // Called from Gradio
+ recalculatePromptTokens('img2img_prompt');
+ recalculatePromptTokens('img2img_neg_prompt');
+ return Array.from(arguments);
+}
+
+function setupTokenCounting(id, id_counter, id_button) {
+ var prompt = gradioApp().getElementById(id);
+ var counter = gradioApp().getElementById(id_counter);
+ var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
+
+ if (opts.disable_token_counters) {
+ counter.style.display = "none";
+ return;
+ }
+
+ if (counter.parentElement == prompt.parentElement) {
+ return;
+ }
+
+ prompt.parentElement.insertBefore(counter, prompt);
+ prompt.parentElement.style.position = "relative";
+
+ promptTokenCountUpdateFunctions[id] = function() {
+ update_token_counter(id_button);
+ };
+ textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
+}
+
+function setupTokenCounters() {
+ setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
+ setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
+ setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
+ setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
+}
diff --git a/javascript/ui.js b/javascript/ui.js
index 648a5290..d70a681b 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) {
}
-var promptTokecountUpdateFuncs = {};
-
-function recalculatePromptTokens(name) {
- if (promptTokecountUpdateFuncs[name]) {
- promptTokecountUpdateFuncs[name]();
- }
-}
-
-function recalculate_prompts_txt2img() {
- recalculatePromptTokens('txt2img_prompt');
- recalculatePromptTokens('txt2img_neg_prompt');
- return Array.from(arguments);
-}
-
-function recalculate_prompts_img2img() {
- recalculatePromptTokens('img2img_prompt');
- recalculatePromptTokens('img2img_neg_prompt');
- return Array.from(arguments);
-}
-
-
var opts = {};
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
if (Object.keys(opts).length != 0) return;
var json_elem = gradioApp().getElementById('settings_json');
@@ -302,28 +281,7 @@ onUiUpdate(function() {
json_elem.parentElement.style.display = "none";
- function registerTextarea(id, id_counter, id_button) {
- var prompt = gradioApp().getElementById(id);
- var counter = gradioApp().getElementById(id_counter);
- var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
- if (counter.parentElement == prompt.parentElement) {
- return;
- }
-
- prompt.parentElement.insertBefore(counter, prompt);
- prompt.parentElement.style.position = "relative";
-
- promptTokecountUpdateFuncs[id] = function() {
- update_token_counter(id_button);
- };
- textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
- }
-
- registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
- registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
- registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
- registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
+ setupTokenCounters();
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
var settings_tabs = gradioApp().querySelector('#settings div');
@@ -354,33 +312,6 @@ onOptionsChanged(function() {
});
let txt2img_textarea, img2img_textarea = undefined;
-let wait_time = 800;
-let token_timeouts = {};
-
-function update_txt2img_tokens(...args) {
- update_token_counter("txt2img_token_button");
- if (args.length == 2) {
- return args[0];
- }
- return args;
-}
-
-function update_img2img_tokens(...args) {
- update_token_counter(
- "img2img_token_button"
- );
- if (args.length == 2) {
- return args[0];
- }
- return args;
-}
-
-function update_token_counter(button_id) {
- if (token_timeouts[button_id]) {
- clearTimeout(token_timeouts[button_id]);
- }
- token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
-}
function restart_reload() {
document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
diff --git a/modules/api/api.py b/modules/api/api.py
index eee99bbb..6a456861 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -700,4 +700,4 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
diff --git a/modules/devices.py b/modules/devices.py
index d8a34a0f..1ed6ffdc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,5 +1,7 @@
import sys
import contextlib
+from functools import lru_cache
+
import torch
from modules import errors
@@ -154,3 +156,19 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+ """
+ just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
+ spends about 2.7 seconds doing that, at least wih NVidia.
+ """
+
+ x = torch.zeros((1, 1)).to(device, dtype)
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
+ linear(x)
+
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+ conv2d(x)
diff --git a/modules/errors.py b/modules/errors.py
index f6b80dbb..da4694f8 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -12,9 +12,13 @@ def print_error_explanation(message):
print('=' * max_len, file=sys.stderr)
-def display(e: Exception, task):
+def display(e: Exception, task, *, full_traceback=False):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ te = traceback.TracebackException.from_exception(e)
+ if full_traceback:
+ # include frames leading up to the try-catch block
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
+ print(*te.format(), sep="", file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d5f0a49b..071bd9ea 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -35,7 +35,7 @@ def reset():
def quote(text):
- if ',' not in str(text) and '\n' not in str(text):
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
return text
return json.dumps(text, ensure_ascii=False)
@@ -306,6 +306,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "RNG" not in res:
res["RNG"] = "GPU"
+ if "Schedule type" not in res:
+ res["Schedule type"] = "Automatic"
+
+ if "Schedule max sigma" not in res:
+ res["Schedule max sigma"] = 0
+
+ if "Schedule min sigma" not in res:
+ res["Schedule min sigma"] = 0
+
+ if "Schedule rho" not in res:
+ res["Schedule rho"] = 0
+
return res
@@ -318,6 +330,10 @@ infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
+ ('Schedule type', 'k_sched_type'),
+ ('Schedule max sigma', 'sigma_max'),
+ ('Schedule min sigma', 'sigma_min'),
+ ('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
diff --git a/modules/images.py b/modules/images.py
index 4e8cd993..e21e554c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -21,6 +21,8 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
+import modules.sd_vae as sd_vae
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -336,8 +338,20 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
+ def get_vae_filename(self): #get the name of the VAE file.
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
+ else:
+ return split_file_name[0]
+
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.image.width,
@@ -354,19 +368,23 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
- 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+ 'batch_size': lambda self: self.p.batch_size,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'vae_filename': lambda self: self.get_vae_filename(),
+
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt, image):
+ def __init__(self, p, seed, prompt, image, zip=False):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
+ self.zip = zip
def hasprompt(self, *args):
lower = self.prompt.lower()
@@ -665,9 +683,10 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity']:
+ items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
diff --git a/modules/img2img.py b/modules/img2img.py
index d704bf90..4c12c2c5 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -92,7 +92,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
diff --git a/modules/processing.py b/modules/processing.py
index 29a3743f..b75f2515 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 40f388a5..d2728e12 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -111,6 +111,7 @@ callback_map = dict(
callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
+ callbacks_list_unets=[],
)
@@ -271,6 +272,18 @@ def list_optimizers_callback():
return res
+def list_unets_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_unets']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_unets')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -430,3 +443,10 @@ def on_list_optimizers(callback):
to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)
+
+
+def on_list_unets(callback):
+ """register a function to be called when UI is making a list of alternative options for unet.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
+
+ add_callback(callback_map['callbacks_list_unets'], callback)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f93df0a6..487dfd60 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,7 +43,7 @@ def list_optimizers():
optimizers.extend(new_optimizers)
-def apply_optimizations():
+def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
@@ -60,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo()
current_optimizer = None
- selection = shared.opts.cross_attention_optimization
+ selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
@@ -72,12 +72,13 @@ def apply_optimizations():
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
- print(f"Applying optimization: {matching_optimizer.name}... ", end='')
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
+ print("Disabling attention optimization")
return ''
@@ -155,9 +156,9 @@ class StableDiffusionModelHijack:
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
- def apply_optimizations(self):
+ def apply_optimizations(self, option=None):
try:
- self.optimization_method = apply_optimizations()
+ self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
@@ -194,6 +195,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
+
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -215,6 +221,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b1afbaa7..232eb9c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -164,6 +164,7 @@ def model_hash(filename):
def select_checkpoint():
+ """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
@@ -171,14 +172,14 @@ def select_checkpoint():
return checkpoint_info
if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
+ error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
+ raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
@@ -423,7 +424,7 @@ class SdModelData:
try:
load_model()
except Exception as e:
- errors.display(e, "loading stable diffusion model")
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
@@ -508,6 +509,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
+ with devices.autocast(), torch.no_grad():
+ sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+
+ timer.record("calculate empty prompt")
+
print(f"Model loaded in {timer.summary()}.")
return sd_model
@@ -527,6 +533,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 59982fc9..e9ba2c61 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -44,6 +44,14 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+ 'Automatic': None,
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -125,6 +133,16 @@ class CFGDenoiser(torch.nn.Module):
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
+ # TODO add infotext entry
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ elif num_repeats > 0:
+ uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
@@ -255,6 +273,13 @@ class KDiffusionSampler:
try:
return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
@@ -294,6 +319,31 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif opts.k_sched_type != "Automatic":
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+ sigmas_kwargs = {
+ 'sigma_min': sigma_min,
+ 'sigma_max': sigma_max,
+ }
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+ sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["Schedule rho"] = opts.rho
+
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
new file mode 100644
index 00000000..6d708ad2
--- /dev/null
+++ b/modules/sd_unet.py
@@ -0,0 +1,92 @@
+import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+ new_unets = script_callbacks.list_unets_callback()
+
+ unet_options.clear()
+ unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+ option = option or shared.opts.sd_unet
+
+ if option == "None":
+ return None
+
+ if option == "Automatic":
+ name = shared.sd_model.sd_checkpoint_info.model_name
+
+ options = [x for x in unet_options if x.model_name == name]
+
+ option = options[0].label if options else "None"
+
+ return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+ global current_unet_option
+ global current_unet
+
+ new_option = get_unet_option(option)
+ if new_option == current_unet_option:
+ return
+
+ if current_unet is not None:
+ print(f"Dectivating unet: {current_unet.option.label}")
+ current_unet.deactivate()
+
+ current_unet_option = new_option
+ if current_unet_option is None:
+ current_unet = None
+
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ shared.sd_model.model.diffusion_model.to(devices.device)
+
+ return
+
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
+ devices.torch_gc()
+
+ current_unet = current_unet_option.create_unet()
+ current_unet.option = current_unet_option
+ print(f"Activating unet: {current_unet.option.label}")
+ current_unet.activate()
+
+
+class SdUnetOption:
+ model_name = None
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+ label = None
+ """name of the unet in UI"""
+
+ def create_unet(self):
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+ raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+ def forward(self, x, timesteps, context, *args, **kwargs):
+ raise NotImplementedError()
+
+ def activate(self):
+ pass
+
+ def deactivate(self):
+ pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+
diff --git a/modules/shared.py b/modules/shared.py
index 3099d1d2..4d59fbf1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -314,6 +314,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -403,6 +404,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -423,6 +425,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -487,6 +490,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order").needs_restart(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
@@ -515,6 +519,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 2a8713c8..7f306a06 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -29,3 +29,14 @@ def cross_attention_optimizations():
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+def sd_unet_items():
+ import modules.sd_unet
+
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
+
+
+def refresh_unet_list():
+ import modules.sd_unet
+
+ modules.sd_unet.list_unets()
+
diff --git a/modules/ui.py b/modules/ui.py
index 361f596e..6189ceeb 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -616,7 +616,8 @@ def create_ui():
outputs=[
txt2img_prompt,
txt_prompt_img
- ]
+ ],
+ show_progress=False,
)
enable_hr.change(
@@ -902,7 +903,8 @@ def create_ui():
outputs=[
img2img_prompt,
img2img_prompt_img
- ]
+ ],
+ show_progress=False,
)
img2img_args = dict(
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 27ab3ebb..5a9204a4 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -50,9 +50,10 @@ def save_files(js_data, images, do_make_zip, index):
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
+ only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
+ only_one = True
images = [images[index]]
start_index = index
@@ -70,6 +71,7 @@ def save_files(js_data, images, do_make_zip, index):
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
+ p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
@@ -83,7 +85,10 @@ def save_files(js_data, images, do_make_zip, index):
# Make Zip
if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
+ zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
+ namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
+ zip_filepath = os.path.join(path, f"{zip_filename}.zip")
from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
diff --git a/requirements.txt b/requirements.txt
index 34e4520d..a464447b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ blendmodes
accelerate
basicsr
gfpgan
-gradio==3.31.0
+gradio==3.32.0
numpy
omegaconf
opencv-contrib-python
diff --git a/requirements_versions.txt b/requirements_versions.txt
index de501fda..31b179a9 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -3,7 +3,7 @@ transformers==4.25.1
accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
-gradio==3.31.0
+gradio==3.32.0
numpy==1.23.5
Pillow==9.5.0
realesrgan==0.3.0
diff --git a/script.js b/script.js
index f7612779..de9d7e22 100644
--- a/script.js
+++ b/script.js
@@ -19,35 +19,79 @@ function get_uiCurrentTabContent() {
}
var uiUpdateCallbacks = [];
+var uiAfterUpdateCallbacks = [];
var uiLoadedCallbacks = [];
var uiTabChangeCallbacks = [];
var optionsChangedCallbacks = [];
+var uiAfterUpdateTimeout = null;
var uiCurrentTab = null;
+/**
+ * Register callback to be called at each UI update.
+ * The callback receives an array of MutationRecords as an argument.
+ */
function onUiUpdate(callback) {
uiUpdateCallbacks.push(callback);
}
+
+/**
+ * Register callback to be called soon after UI updates.
+ * The callback receives no arguments.
+ *
+ * This is preferred over `onUiUpdate` if you don't need
+ * access to the MutationRecords, as your function will
+ * not be called quite as often.
+ */
+function onAfterUiUpdate(callback) {
+ uiAfterUpdateCallbacks.push(callback);
+}
+
+/**
+ * Register callback to be called when the UI is loaded.
+ * The callback receives no arguments.
+ */
function onUiLoaded(callback) {
uiLoadedCallbacks.push(callback);
}
+
+/**
+ * Register callback to be called when the UI tab is changed.
+ * The callback receives no arguments.
+ */
function onUiTabChange(callback) {
uiTabChangeCallbacks.push(callback);
}
+
+/**
+ * Register callback to be called when the options are changed.
+ * The callback receives no arguments.
+ * @param callback
+ */
function onOptionsChanged(callback) {
optionsChangedCallbacks.push(callback);
}
-function runCallback(x, m) {
- try {
- x(m);
- } catch (e) {
- (console.error || console.log).call(console, e.message, e);
+function executeCallbacks(queue, arg) {
+ for (const callback of queue) {
+ try {
+ callback(arg);
+ } catch (e) {
+ console.error("error running callback", callback, ":", e);
+ }
}
}
-function executeCallbacks(queue, m) {
- queue.forEach(function(x) {
- runCallback(x, m);
- });
+
+/**
+ * Schedule the execution of the callbacks registered with onAfterUiUpdate.
+ * The callbacks are executed after a short while, unless another call to this function
+ * is made before that time. IOW, the callbacks are executed only once, even
+ * when there are multiple mutations observed.
+ */
+function scheduleAfterUiUpdateCallbacks() {
+ clearTimeout(uiAfterUpdateTimeout);
+ uiAfterUpdateTimeout = setTimeout(function() {
+ executeCallbacks(uiAfterUpdateCallbacks);
+ }, 200);
}
var executedOnLoaded = false;
@@ -60,6 +104,7 @@ document.addEventListener("DOMContentLoaded", function() {
}
executeCallbacks(uiUpdateCallbacks, m);
+ scheduleAfterUiUpdateCallbacks();
const newTab = get_uiCurrentTab();
if (newTab && (newTab !== uiCurrentTab)) {
uiCurrentTab = newTab;
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index da820b39..7821cc65 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
@@ -220,6 +220,10 @@ axis_options = [
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
+ AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
+ AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
+ AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
diff --git a/webui.py b/webui.py
index f0ffbbbf..1e3ff061 100644
--- a/webui.py
+++ b/webui.py
@@ -20,7 +20,7 @@ import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors # noqa: F401
+from modules import paths, timer, import_hook, errors, devices # noqa: F401
startup_timer = timer.Timer()
@@ -58,6 +58,7 @@ import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
+import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
@@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False):
modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
+ modules.sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
+
def load_model():
"""
Accesses shared.sd_model property to load model.
@@ -306,6 +310,8 @@ def initialize_rest(*, reload_script_modules=False):
Thread(target=load_model).start()
+ Thread(target=devices.first_time_calculation).start()
+
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
@@ -381,17 +387,6 @@ def webui():
gradio_auth_creds = list(get_gradio_auth_creds()) or None
- # this restores the missing /docs endpoint
- if launch_api and not hasattr(FastAPI, 'original_setup'):
- # TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
- def fastapi_setup(self):
- self.docs_url = "/docs"
- self.redoc_url = "/redoc"
- self.original_setup()
-
- FastAPI.original_setup = FastAPI.setup
- FastAPI.setup = fastapi_setup
-
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
@@ -404,6 +399,10 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
+ app_kwargs={
+ "docs_url": "/docs",
+ "redoc_url": "/redoc",
+ },
)
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
diff --git a/webui.sh b/webui.sh
index ab52ac3b..607557b1 100755
--- a/webui.sh
+++ b/webui.sh
@@ -124,9 +124,12 @@ case "$gpu_info" in
*)
;;
esac
-if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+if ! echo "$gpu_info" | grep -q "NVIDIA";
then
- export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+ then
+ export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ fi
fi
for preq in "${GIT}" "${python_cmd}"
@@ -190,7 +193,7 @@ fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
- TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
+ TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"