aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.eslintrc.js1
-rw-r--r--configs/alt-diffusion-m18-inference.yaml73
-rw-r--r--extensions-builtin/Lora/network_glora.py33
-rw-r--r--extensions-builtin/Lora/networks.py3
-rw-r--r--javascript/dragdrop.js2
-rw-r--r--javascript/edit-attention.js22
-rw-r--r--javascript/extraNetworks.js25
-rw-r--r--javascript/imageviewer.js7
-rw-r--r--javascript/settings.js46
-rw-r--r--javascript/token-counters.js26
-rw-r--r--javascript/ui.js32
-rw-r--r--modules/api/api.py26
-rw-r--r--modules/api/models.py24
-rw-r--r--modules/cmd_args.py7
-rw-r--r--modules/config_states.py3
-rw-r--r--modules/devices.py3
-rw-r--r--modules/generation_parameters_copypaste.py2
-rw-r--r--modules/gitpython_hack.py2
-rw-r--r--modules/hypernetworks/hypernetwork.py4
-rw-r--r--modules/images.py19
-rw-r--r--modules/img2img.py20
-rw-r--r--modules/initialize.py4
-rw-r--r--modules/launch_utils.py4
-rw-r--r--modules/localization.py21
-rw-r--r--modules/options.py2
-rw-r--r--modules/paths.py2
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/processing.py9
-rw-r--r--modules/processing_scripts/seed.py4
-rw-r--r--modules/prompt_parser.py7
-rw-r--r--modules/restart.py4
-rw-r--r--modules/script_callbacks.py6
-rw-r--r--modules/scripts.py6
-rw-r--r--modules/sd_hijack.py21
-rw-r--r--modules/sd_models.py37
-rw-r--r--modules/sd_models_config.py5
-rw-r--r--modules/sd_unet.py4
-rw-r--r--modules/shared_items.py4
-rw-r--r--modules/shared_options.py15
-rw-r--r--modules/shared_state.py2
-rw-r--r--modules/sub_quadratic_attention.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py4
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py33
-rw-r--r--modules/ui_extensions.py2
-rw-r--r--modules/ui_extra_networks.py32
-rw-r--r--modules/ui_gradio_extensions.py6
-rw-r--r--modules/ui_loadsave.py22
-rw-r--r--modules/ui_prompt_styles.py28
-rw-r--r--modules/ui_settings.py35
-rw-r--r--modules/xlmr_m18.py164
-rw-r--r--requirements_versions.txt2
-rw-r--r--scripts/postprocessing_upscale.py2
-rw-r--r--scripts/prompts_from_file.py15
-rw-r--r--scripts/xyz_grid.py7
-rw-r--r--style.css22
-rw-r--r--webui.bat1
-rwxr-xr-xwebui.sh19
58 files changed, 708 insertions, 232 deletions
diff --git a/.eslintrc.js b/.eslintrc.js
index 4777c276..cf839769 100644
--- a/.eslintrc.js
+++ b/.eslintrc.js
@@ -74,6 +74,7 @@ module.exports = {
create_submit_args: "readonly",
restart_reload: "readonly",
updateInput: "readonly",
+ onEdit: "readonly",
//extraNetworks.js
requestGet: "readonly",
popup: "readonly",
diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml
new file mode 100644
index 00000000..41a031d5
--- /dev/null
+++ b/configs/alt-diffusion-m18-inference.yaml
@@ -0,0 +1,73 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: modules.xlmr_m18.BertSeriesModelWithTransformation
+ params:
+ name: "XLMR-Large"
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
new file mode 100644
index 00000000..492d4870
--- /dev/null
+++ b/extensions-builtin/Lora/network_glora.py
@@ -0,0 +1,33 @@
+
+import network
+
+class ModuleTypeGLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
+ return NetworkModuleGLora(net, weights)
+
+ return None
+
+# adapted from https://github.com/KohakuBlueleaf/LyCORIS
+class NetworkModuleGLora(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.w1a = weights.w["a1.weight"]
+ self.w1b = weights.w["b1.weight"]
+ self.w2a = weights.w["a2.weight"]
+ self.w2b = weights.w["b2.weight"]
+
+ def calc_updown(self, orig_weight):
+ w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [w1a.size(0), w1b.size(1)]
+ updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 96f935b2..ddab3c55 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -5,6 +5,7 @@ import re
import lora_patches
import network
import network_lora
+import network_glora
import network_hada
import network_ia3
import network_lokr
@@ -23,6 +24,7 @@ module_types = [
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
]
@@ -418,6 +420,7 @@ def network_forward(module, input, original_forward):
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
self.network_current_names = ()
self.network_weights_backup = None
+ self.network_bias_backup = None
def network_Linear_forward(self, input):
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index 5803daea..d680daf5 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -119,7 +119,7 @@ window.addEventListener('paste', e => {
}
const firstFreeImageField = visibleImageFields
- .filter(el => el.querySelector('input[type=file]'))?.[0];
+ .filter(el => !el.querySelector('img'))?.[0];
dropReplaceImage(
firstFreeImageField ?
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index 8906c892..794453bf 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -18,22 +18,11 @@ function keyupEditAttention(event) {
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
- let beforeParenClose = before.lastIndexOf(CLOSE);
- while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
- beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
- beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
- }
// Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
- let afterParenOpen = after.indexOf(OPEN);
- while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
- afterParen = after.indexOf(CLOSE, afterParen + 1);
- afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
- }
- if (beforeParen === -1 || afterParen === -1) return false;
// Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
@@ -46,9 +35,14 @@ function keyupEditAttention(event) {
function selectCurrentWord() {
if (selectionStart !== selectionEnd) return false;
- const delimiters = opts.keyedit_delimiters + " \r\n\t";
+ const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"};
+ let delimiters = opts.keyedit_delimiters;
+
+ for (let i of opts.keyedit_delimiters_whitespace) {
+ delimiters += whitespace_delimiters[i];
+ }
- // seek backward until to find beggining
+ // seek backward to find beginning
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
@@ -92,7 +86,7 @@ function keyupEditAttention(event) {
}
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 493f31af..ac26718f 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -140,14 +140,15 @@ function setupExtraNetworks() {
onUiLoaded(setupExtraNetworks);
-var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
-var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
+var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
+var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var m = text.match(re_extranet);
var replaced = false;
var newTextareaText;
if (m) {
+ var extraTextBeforeNet = opts.extra_networks_add_text_separator;
var extraTextAfterNet = m[2];
var partToSearch = m[1];
var foundAtPosition = -1;
@@ -161,8 +162,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
return found;
});
- if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
- newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
+ if (foundAtPosition >= 0) {
+ if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+ newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
+ }
+ if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
+ newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition);
+ }
}
} else {
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
@@ -216,27 +222,24 @@ function extraNetworksSearchButton(tabs_id, event) {
var globalPopup = null;
var globalPopupInner = null;
+
function closePopup() {
if (!globalPopup) return;
-
globalPopup.style.display = "none";
}
+
function popup(contents) {
if (!globalPopup) {
globalPopup = document.createElement('div');
- globalPopup.onclick = closePopup;
globalPopup.classList.add('global-popup');
var close = document.createElement('div');
close.classList.add('global-popup-close');
- close.onclick = closePopup;
+ close.addEventListener("click", closePopup);
close.title = "Close";
globalPopup.appendChild(close);
globalPopupInner = document.createElement('div');
- globalPopupInner.onclick = function(event) {
- event.stopPropagation(); return false;
- };
globalPopupInner.classList.add('global-popup-inner');
globalPopup.appendChild(globalPopupInner);
@@ -335,7 +338,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
function extraNetworksRefreshSingleCard(page, tabname, name) {
requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
if (data && data.html) {
- var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function
+ var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`);
var newDiv = document.createElement('DIV');
newDiv.innerHTML = data.html;
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index c21d396e..e4dae91b 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage");
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
-
- if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
+ let preview = gradioApp().querySelectorAll('.livePreview > img');
+ if (preview.length > 0) {
+ // show preview image if available
+ modalImage.src = preview[preview.length - 1].src;
+ } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') {
const modal = gradioApp().getElementById("lightboxModal");
diff --git a/javascript/settings.js b/javascript/settings.js
new file mode 100644
index 00000000..4e79ec00
--- /dev/null
+++ b/javascript/settings.js
@@ -0,0 +1,46 @@
+let settingsExcludeTabsFromShowAll = {
+ settings_tab_defaults: 1,
+ settings_tab_sysinfo: 1,
+ settings_tab_actions: 1,
+ settings_tab_licenses: 1,
+};
+
+function settingsShowAllTabs() {
+ gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
+ if (settingsExcludeTabsFromShowAll[elem.id]) return;
+
+ elem.style.display = "block";
+ });
+}
+
+function settingsShowOneTab() {
+ gradioApp().querySelector('#settings_show_one_page').click();
+}
+
+onUiLoaded(function() {
+ var edit = gradioApp().querySelector('#settings_search');
+ var editTextarea = gradioApp().querySelector('#settings_search > label > input');
+ var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages');
+ var settings_tabs = gradioApp().querySelector('#settings div');
+
+ onEdit('settingsSearch', editTextarea, 250, function() {
+ var searchText = (editTextarea.value || "").trim().toLowerCase();
+
+ gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) {
+ var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1;
+ elem.style.display = visible ? "" : "none";
+ });
+
+ if (searchText != "") {
+ settingsShowAllTabs();
+ } else {
+ settingsShowOneTab();
+ }
+ });
+
+ settings_tabs.insertBefore(edit, settings_tabs.firstChild);
+ settings_tabs.appendChild(buttonShowAllPages);
+
+
+ buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
+});
diff --git a/javascript/token-counters.js b/javascript/token-counters.js
index 9d81a723..2ecc7d91 100644
--- a/javascript/token-counters.js
+++ b/javascript/token-counters.js
@@ -1,10 +1,9 @@
-let promptTokenCountDebounceTime = 800;
-let promptTokenCountTimeouts = {};
-var promptTokenCountUpdateFunctions = {};
+let promptTokenCountUpdateFunctions = {};
function update_txt2img_tokens(...args) {
// Called from Gradio
update_token_counter("txt2img_token_button");
+ update_token_counter("txt2img_negative_token_button");
if (args.length == 2) {
return args[0];
}
@@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) {
function update_img2img_tokens(...args) {
// Called from Gradio
update_token_counter("img2img_token_button");
+ update_token_counter("img2img_negative_token_button");
if (args.length == 2) {
return args[0];
}
@@ -21,16 +21,7 @@ function update_img2img_tokens(...args) {
}
function update_token_counter(button_id) {
- if (opts.disable_token_counters) {
- return;
- }
- if (promptTokenCountTimeouts[button_id]) {
- clearTimeout(promptTokenCountTimeouts[button_id]);
- }
- promptTokenCountTimeouts[button_id] = setTimeout(
- () => gradioApp().getElementById(button_id)?.click(),
- promptTokenCountDebounceTime,
- );
+ promptTokenCountUpdateFunctions[button_id]?.();
}
@@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) {
prompt.parentElement.insertBefore(counter, prompt);
prompt.parentElement.style.position = "relative";
- promptTokenCountUpdateFunctions[id] = function() {
- update_token_counter(id_button);
- };
- textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
+ var func = onEdit(id, textarea, 800, function() {
+ gradioApp().getElementById(id_button)?.click();
+ });
+ promptTokenCountUpdateFunctions[id] = func;
+ promptTokenCountUpdateFunctions[id_button] = func;
}
function setupTokenCounters() {
diff --git a/javascript/ui.js b/javascript/ui.js
index bedcbf3e..2e262602 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -263,21 +263,6 @@ onAfterUiUpdate(function() {
json_elem.parentElement.style.display = "none";
setupTokenCounters();
-
- var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
- var settings_tabs = gradioApp().querySelector('#settings div');
- if (show_all_pages && settings_tabs) {
- settings_tabs.appendChild(show_all_pages);
- show_all_pages.onclick = function() {
- gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
- if (elem.id == "settings_tab_licenses") {
- return;
- }
-
- elem.style.display = "block";
- });
- };
- }
});
onOptionsChanged(function() {
@@ -366,3 +351,20 @@ function switchWidthHeight(tabname) {
updateInput(height);
return [];
}
+
+
+var onEditTimers = {};
+
+// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
+function onEdit(editId, elem, afterMs, func) {
+ var edited = function() {
+ var existingTimer = onEditTimers[editId];
+ if (existingTimer) clearTimeout(existingTimer);
+
+ onEditTimers[editId] = setTimeout(func, afterMs);
+ };
+
+ elem.addEventListener("input", edited);
+
+ return edited;
+}
diff --git a/modules/api/api.py b/modules/api/api.py
index 844e31ee..905ef9c9 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -29,7 +29,7 @@ from modules.sd_models import unload_model_weights, reload_model_weights, checkp
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
-from typing import Dict, List, Any
+from typing import Any
import piexif
import piexif.helper
from contextlib import closing
@@ -221,15 +221,15 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
- self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
- self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@@ -242,8 +242,8 @@ class Api:
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
- self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
- self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem])
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
+ self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -563,7 +563,7 @@ class Api:
return options
- def set_config(self, req: Dict[str, Any]):
+ def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
diff --git a/modules/api/models.py b/modules/api/models.py
index 94eca97d..a0d80af8 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,12 +1,10 @@
import inspect
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
+from typing import Any, Optional, Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -168,10 +166,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
@@ -233,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
- aliases: List[str] = Field(title="Aliases")
- options: Dict[str, str] = Field(title="Options")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
@@ -285,8 +283,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
- loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+ loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
@@ -304,14 +302,14 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
- choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+ choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
- args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+ args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
class ExtensionItem(BaseModel):
name: str = Field(title="Name", description="Extension name")
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index aab62286..4e602a84 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -90,7 +90,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
-parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
+parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
@@ -112,8 +112,9 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
-parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
+parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
diff --git a/modules/config_states.py b/modules/config_states.py
index b766aef1..651793c7 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -4,7 +4,6 @@ Supports saving and restoring webui and extensions from a known working set of c
import os
import json
-import time
import tqdm
from datetime import datetime
@@ -38,7 +37,7 @@ def list_config_states():
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states:
- timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
name = cs.get("name", "Config")
full_name = f"{name}: {timestamp}"
all_config_states[full_name] = cs
diff --git a/modules/devices.py b/modules/devices.py
index c01f0602..1d4eb563 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -60,7 +60,8 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
+ device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
+ if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d39f2eba..0a606515 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -9,7 +9,7 @@ from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks, processing
from PIL import Image
-re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
+re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py
index e537c1df..b55f0640 100644
--- a/modules/gitpython_hack.py
+++ b/modules/gitpython_hack.py
@@ -23,7 +23,7 @@ class Git(git.Git):
)
return self._parse_object_header(ret)
- def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
# Not really streaming, per se; this buffers the entire object in memory.
# Shouldn't be a problem for our use case, since we're only using this for
# object headers (commit objects).
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 70f1cbd2..be3e4648 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
@@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
diff --git a/modules/images.py b/modules/images.py
index eb644733..daf4eebe 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
})
piexif.insert(exif_bytes, filename)
+ elif extension.lower() == ".gif":
+ image.save(filename, format=image_format, comment=geninfo)
else:
image.save(filename, format=image_format, quality=opts.jpeg_quality)
@@ -661,7 +663,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
- os.replace(temp_file_path, filename_without_extension + extension)
+ filename = filename_without_extension + extension
+ if shared.opts.save_images_replace_action != "Replace":
+ n = 0
+ while os.path.exists(filename):
+ n += 1
+ filename = f"{filename_without_extension}-{n}{extension}"
+ os.replace(temp_file_path, filename)
fullfn_without_extension, extension = os.path.splitext(params.filename)
if hasattr(os, 'statvfs'):
@@ -718,7 +726,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
geninfo = items.pop('parameters', None)
if "exif" in items:
- exif = piexif.load(items["exif"])
+ exif_data = items["exif"]
+ try:
+ exif = piexif.load(exif_data)
+ except OSError:
+ # memory / exif was not valid so piexif tried to read from a file
+ exif = None
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
@@ -728,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
if exif_comment:
items['exif comment'] = exif_comment
geninfo = exif_comment
+ elif "comment" in items: # for gif
+ geninfo = items["comment"].decode('utf8', errors="ignore")
for field in IGNORED_INFO_KEYS:
items.pop(field, None)
diff --git a/modules/img2img.py b/modules/img2img.py
index c81c7ab9..52cb577a 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -114,15 +114,19 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
else:
p.override_settings.pop("sd_model_checkpoint", None)
+ if output_dir:
+ p.outpath_samples = output_dir
+ p.override_settings['save_to_dirs'] = False
+ p.override_settings['save_images_replace_action'] = "Add number suffix"
+ if p.n_iter > 1 or p.batch_size > 1:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+ else:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+
proc = modules.scripts.scripts_img2img.run(p, *args)
+
if proc is None:
- if output_dir:
- p.outpath_samples = output_dir
- p.override_settings['save_to_dirs'] = False
- if p.n_iter > 1 or p.batch_size > 1:
- p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
- else:
- p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+ p.override_settings.pop('save_images_replace_action', None)
process_images(p)
@@ -199,7 +203,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.user = request.username
- if shared.cmd_opts.enable_console_prompts:
+ if shared.opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask:
diff --git a/modules/initialize.py b/modules/initialize.py
index f24f7637..ac95fc6f 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -151,8 +151,8 @@ def initialize_rest(*, reload_script_modules=False):
from modules import devices
devices.first_time_calculation()
-
- Thread(target=load_model).start()
+ if not shared.cmd_opts.skip_load_model_at_start:
+ Thread(target=load_model).start()
from modules import shared_items
shared_items.reload_hypernetworks()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 6e54d063..8cdbafa5 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -64,7 +64,7 @@ Use --skip-python-version-check to suppress this warning.
@lru_cache()
def commit_hash():
try:
- return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
+ return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception:
return "<none>"
@@ -72,7 +72,7 @@ def commit_hash():
@lru_cache()
def git_tag():
try:
- return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
+ return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
diff --git a/modules/localization.py b/modules/localization.py
index c1320288..108f792e 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -14,21 +14,24 @@ def list_localizations(dirname):
if ext.lower() != ".json":
continue
- localizations[fn] = os.path.join(dirname, file)
+ localizations[fn] = [os.path.join(dirname, file)]
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
- localizations[fn] = file.path
+ if fn not in localizations:
+ localizations[fn] = []
+ localizations[fn].append(file.path)
def localization_js(current_localization_name: str) -> str:
- fn = localizations.get(current_localization_name, None)
+ fns = localizations.get(current_localization_name, None)
data = {}
- if fn is not None:
- try:
- with open(fn, "r", encoding="utf8") as file:
- data = json.load(file)
- except Exception:
- errors.report(f"Error loading localization from {fn}", exc_info=True)
+ if fns is not None:
+ for fn in fns:
+ try:
+ with open(fn, "r", encoding="utf8") as file:
+ data.update(json.load(file))
+ except Exception:
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/options.py b/modules/options.py
index 758b1ce5..ab40aff7 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -210,6 +210,8 @@ class Options:
def add_option(self, key, info):
self.data_labels[key] = info
+ if key not in self.data:
+ self.data[key] = info.default
def reorder(self):
"""reorder settings so that all items related to section always go together"""
diff --git a/modules/paths.py b/modules/paths.py
index 25052339..187b9496 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,6 +1,6 @@
import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
import modules.safe # noqa: F401
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 005a9b0a..89131a54 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -8,6 +8,7 @@ import shlex
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
+cwd = os.getcwd()
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
diff --git a/modules/processing.py b/modules/processing.py
index e124e7f0..40598f5c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -142,7 +142,7 @@ class StableDiffusionProcessing:
overlay_images: list = None
eta: float = None
do_not_reload_embeddings: bool = False
- denoising_strength: float = 0
+ denoising_strength: float = None
ddim_discretize: str = None
s_min_uncond: float = None
s_churn: float = None
@@ -533,6 +533,7 @@ class Processed:
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
+ self.version = program_version()
def js(self):
obj = {
@@ -567,6 +568,7 @@ class Processed:
"job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
+ "version": self.version,
}
return json.dumps(obj)
@@ -709,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.before_process(p)
- stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+ stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@@ -958,6 +960,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.nextjob()
+ if not infotexts:
+ infotexts.append(Processed(p, []).infotext(p, 0))
+
p.color_corrections = None
index_of_first_image = 0
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index 6b6ff987..dc9c2da5 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -29,8 +29,8 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
else:
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
- random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
- reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
+ random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
+ reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 334efeef..ddf4d2dd 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -2,7 +2,6 @@ from __future__ import annotations
import re
from collections import namedtuple
-from typing import List
import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
@@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
- self.schedules: List[ScheduledPromptConditioning] = schedules
+ self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight
class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
- self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+ self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
@@ -278,7 +277,7 @@ class DictWithShape(dict):
return self["crossattn"].shape
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
is_dict = isinstance(param, dict)
diff --git a/modules/restart.py b/modules/restart.py
index 18eacaf3..2dd6493b 100644
--- a/modules/restart.py
+++ b/modules/restart.py
@@ -14,7 +14,9 @@ def is_restartable() -> bool:
def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
- (Path(script_path) / "tmp" / "restart").touch()
+ tmpdir = Path(script_path) / "tmp"
+ tmpdir.mkdir(parents=True, exist_ok=True)
+ (tmpdir / "restart").touch()
stop_program()
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index c99695eb..9ed7ad21 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,7 +1,7 @@
import inspect
import os
from collections import namedtuple
-from typing import Optional, Dict, Any
+from typing import Optional, Any
from fastapi import FastAPI
from gradio import Blocks
@@ -258,7 +258,7 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']:
try:
c.callback(infotext, params)
@@ -449,7 +449,7 @@ def on_infotext_pasted(callback):
"""register a function to be called before applying an infotext.
The callback is called with two arguments:
- infotext: str - raw infotext.
- - result: Dict[str, any] - parsed infotext parameters.
+ - result: dict[str, any] - parsed infotext parameters.
"""
add_callback(callback_map['callbacks_infotext_pasted'], callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index e8518ad0..5c6e0226 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -491,11 +491,15 @@ class ScriptRunner:
arg_info = api_models.ScriptArg(label=control.label or "")
- for field in ("value", "minimum", "maximum", "step", "choices"):
+ for field in ("value", "minimum", "maximum", "step"):
v = getattr(control, field, None)
if v is not None:
setattr(arg_info, field, v)
+ choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
+ if choices is not None:
+ arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]
+
api_args.append(arg_info)
script.api_info = api_models.ScriptInfo(
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 592f0055..bc5fbcd3 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,14 +2,15 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
+import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -37,6 +38,8 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
@@ -208,7 +211,7 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
@@ -239,10 +242,13 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
- if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
- ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+ if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
+ sd_unet.original_forward = sgm_original_forward
+ else:
+ sd_unet.original_forward = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
@@ -279,7 +285,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+ sd_unet.original_forward = None
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..c8efeedc 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,16 +7,17 @@ import threading
import torch
import re
import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
+import numpy as np
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -49,11 +50,12 @@ class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
abspath = os.path.abspath(filename)
+ abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
+ name = abspath.replace(abs_ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
@@ -129,9 +131,12 @@ except Exception:
def setup_model():
+ """called once at startup to do various one-time tasks related to SD models"""
+
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
+ patch_given_betas()
def checkpoint_tiles(use_short=False):
@@ -309,6 +314,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
+ # move to end as latest
+ checkpoints_loaded.move_to_end(checkpoint_info)
return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -350,12 +357,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
-
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
@@ -453,6 +460,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+def patch_given_betas():
+ import ldm.models.diffusion.ddpm
+
+ def patched_register_schedule(*args, **kwargs):
+ """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+ if isinstance(args[1], ListConfig):
+ args = (args[0], np.array(args[1]), *args[2:])
+
+ original_register_schedule(*args, **kwargs)
+
+ original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
+
+
def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 08dd03f1..deab2f6e 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
+config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
@@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
+
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+ if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
+ return config_alt_diffusion_m18
return config_alt_diffusion
return config_default
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 5525cfbc..6a7bc9e2 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -1,11 +1,11 @@
import torch.nn
-import ldm.modules.diffusionmodules.openaimodel
from modules import script_callbacks, shared, devices
unet_options = []
current_unet_option = None
current_unet = None
+original_forward = None
def list_unets():
@@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
- return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+ return original_forward(self, x, timesteps, context, *args, **kwargs)
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 84d69c8d..b1459f8c 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -44,9 +44,9 @@ def refresh_unet_list():
modules.sd_unet.list_unets()
-def list_checkpoint_tiles():
+def list_checkpoint_tiles(use_short=False):
import modules.sd_models
- return modules.sd_models.checkpoint_tiles()
+ return modules.sd_models.checkpoint_tiles(use_short)
def refresh_checkpoints():
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 00b273fa..ce395302 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -26,7 +26,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
-
+ "save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
@@ -62,6 +62,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
+
+ "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -100,6 +102,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), {
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
+ "enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
@@ -133,7 +136,7 @@ options_templates.update(options_section(('training', "Training"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
@@ -255,12 +258,14 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"),
+ "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
+ "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
@@ -305,8 +310,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
- 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
- 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
diff --git a/modules/shared_state.py b/modules/shared_state.py
index d272ee5b..a68789cc 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -103,6 +103,7 @@ class State:
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
+ self.time_start = time.time()
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
@@ -114,7 +115,6 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
- self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index ae4ee4bb..4cb561ef 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,7 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-from typing import Optional, NamedTuple, List
+from typing import Optional, NamedTuple
def narrow_trunc(
@@ -97,7 +97,7 @@ def _query_chunk_attention(
)
return summarize_chunk(query, key_chunk, value_chunk)
- chunks: List[AttnChunk] = [
+ chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index aa79dc09..401a0a2a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -386,7 +386,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
from modules import processing
save_embedding_every = save_embedding_every or 0
@@ -590,7 +590,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 1ee592ad..e4e18ceb 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -3,7 +3,7 @@ from contextlib import closing
import modules.scripts
from modules import processing
from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.shared import opts, cmd_opts
+from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
import gradio as gr
@@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.user = request.username
- if cmd_opts.enable_console_prompts:
+ if shared.opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
with closing(p):
diff --git a/modules/ui.py b/modules/ui.py
index f4028475..bcf39199 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -151,11 +151,15 @@ def connect_clear_prompt(button):
)
-def update_token_counter(text, steps):
+def update_token_counter(text, steps, *, is_positive=True):
try:
text, _ = extra_networks.parse_prompt(text)
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ if is_positive:
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ else:
+ prompt_flat_list = [text]
+
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
except Exception:
@@ -169,6 +173,10 @@ def update_token_counter(text, steps):
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
+def update_negative_prompt_token_counter(text, steps):
+ return update_token_counter(text, steps, is_positive=False)
+
+
class Toprow:
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
@@ -215,9 +223,10 @@ class Toprow:
)
with gr.Row(elem_id=f"{id_part}_tools"):
- self.paste = ToolButton(value=paste_symbol, elem_id="paste")
- self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt")
+ self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress")
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
@@ -232,6 +241,7 @@ class Toprow:
)
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
+ self.ui_styles.setup_apply_button(self.apply_styles)
self.prompt_img.change(
fn=modules.images.image_data,
@@ -348,7 +358,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
@@ -533,7 +543,7 @@ def create_ui():
]
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
- toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@@ -661,8 +671,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
- detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
+ detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
@@ -1286,7 +1296,7 @@ def create_ui():
loadsave.setup_ui()
- if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
@@ -1338,7 +1348,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
def setup_ui_api(app):
from pydantic import BaseModel, Field
- from typing import List
class QuicksettingsHint(BaseModel):
name: str = Field(title="Name of the quicksettings field")
@@ -1347,7 +1356,7 @@ def setup_ui_api(app):
def quicksettings_hint():
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
- app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
+ app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 2e8c1d6d..c0a73b57 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -197,7 +197,7 @@ def update_config_states_table(state_name):
config_state = config_states.all_config_states[state_name]
config_name = config_state.get("name", "Config")
- created_date = time.asctime(time.gmtime(config_state["created_at"]))
+ created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
filepath = config_state.get("filepath", "<unknown>")
try:
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 063bd7b8..59d6ecc6 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,3 +1,4 @@
+import functools
import os.path
import urllib.parse
from pathlib import Path
@@ -15,6 +16,17 @@ from modules.ui_components import ToolButton
extra_pages = []
allowed_dirs = set()
+default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
+
+
+@functools.cache
+def allowed_preview_extensions_with_extra(extra_extensions=None):
+ return set(default_allowed_preview_extensions) | set(extra_extensions or [])
+
+
+def allowed_preview_extensions():
+ return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
+
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
@@ -33,9 +45,9 @@ def fetch_file(filename: str = ""):
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
- ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
+ ext = os.path.splitext(filename)[1].lower()[1:]
+ if ext not in allowed_preview_extensions():
+ raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -213,9 +225,9 @@ class ExtraNetworksPage:
metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
- edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
local_path = ""
filename = item.get("filename", "")
@@ -235,7 +247,7 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never":
return ""
- sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
+ sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
args = {
"background_image": background_image,
@@ -273,11 +285,7 @@ class ExtraNetworksPage:
Find a preview PNG for a given path (without extension) and call link_preview on it.
"""
- preview_extensions = ["png", "jpg", "jpeg", "webp"]
- if shared.opts.samples_format not in preview_extensions:
- preview_extensions.append(shared.opts.samples_format)
-
- potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
+ potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
for file in potential_files:
if os.path.isfile(file):
@@ -374,7 +382,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
- button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
+ button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False, tooltip="Invert sort order")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index b824b113..0d368f8b 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -2,12 +2,12 @@ import os
import gradio as gr
from modules import localization, shared, scripts
-from modules.paths import script_path, data_path
+from modules.paths import script_path, data_path, cwd
def webpath(fn):
- if fn.startswith(script_path):
- web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+ if fn.startswith(cwd):
+ web_path = os.path.relpath(fn, cwd)
else:
web_path = os.path.abspath(fn)
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
index ec8fa8e8..eb20ff25 100644
--- a/modules/ui_loadsave.py
+++ b/modules/ui_loadsave.py
@@ -4,7 +4,7 @@ import os
import gradio as gr
from modules import errors
-from modules.ui_components import ToolButton
+from modules.ui_components import ToolButton, InputAccordion
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
@@ -32,8 +32,6 @@ class UiLoadsave:
self.error_loading = True
errors.display(e, "loading settings")
-
-
def add_component(self, path, x):
"""adds component to the registry of tracked components"""
@@ -43,20 +41,24 @@ class UiLoadsave:
key = f"{path}/{field}"
if getattr(obj, 'custom_script_source', None) is not None:
- key = f"customscript/{obj.custom_script_source}/{key}"
+ key = f"customscript/{obj.custom_script_source}/{key}"
if getattr(obj, 'do_not_save_to_config', False):
return
saved_value = self.ui_settings.get(key, None)
+
+ if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value':
+ field = 'open'
+
if saved_value is None:
self.ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
pass
else:
- if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
+ if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
saved_value = str(saved_value)
- elif isinstance(x, gr.Number) and field == 'value':
+ elif isinstance(obj, gr.Number) and field == 'value':
try:
saved_value = float(saved_value)
except ValueError:
@@ -67,7 +69,7 @@ class UiLoadsave:
init_field(saved_value)
if field == 'value' and key not in self.component_mapping:
- self.component_mapping[key] = x
+ self.component_mapping[key] = obj
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
apply_field(x, 'visible')
@@ -100,6 +102,12 @@ class UiLoadsave:
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+ if type(x) == InputAccordion:
+ if x.accordion.visible:
+ apply_field(x.accordion, 'visible')
+ apply_field(x, 'value')
+ apply_field(x.accordion, 'value')
+
def check_tab_id(tab_id):
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
if type(tab_id) == str:
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 85eb3a64..3bcf092f 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -4,6 +4,7 @@ from modules import shared, ui_common, ui_components, styles
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
styles_materialize_symbol = '\U0001f4cb' # 📋
+styles_copy_symbol = '\U0001f4dd' # 📝
def select_style(name):
@@ -52,6 +53,8 @@ def refresh_styles():
class UiPromptStyles:
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
self.tabname = tabname
+ self.main_ui_prompt = main_ui_prompt
+ self.main_ui_negative_prompt = main_ui_negative_prompt
with gr.Row(elem_id=f"{tabname}_styles_row"):
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
@@ -61,7 +64,8 @@ class UiPromptStyles:
with gr.Row():
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
- self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+ self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+ self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
with gr.Row():
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
@@ -96,15 +100,21 @@ class UiPromptStyles:
show_progress=False,
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
- self.materialize.click(
- fn=materialize_styles,
- inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
- outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ self.setup_apply_button(self.materialize)
+
+ self.copy.click(
+ fn=lambda p, n: (p, n),
+ inputs=[main_ui_prompt, main_ui_negative_prompt],
+ outputs=[self.prompt, self.neg_prompt],
show_progress=False,
- ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+ )
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
-
-
-
+ def setup_apply_button(self, button):
+ button.click(
+ fn=materialize_styles,
+ inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
+ outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
+ show_progress=False,
+ ).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False)
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 8ff9c074..74a3aef3 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -5,6 +5,7 @@ from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript
+from concurrent.futures import ThreadPoolExecutor, as_completed
def get_value_for_setting(key):
@@ -63,6 +64,9 @@ class UiSettings:
quicksettings_list = None
quicksettings_names = None
text_settings = None
+ show_all_pages = None
+ show_one_page = None
+ search_input = None
def run_settings(self, *args):
changed = []
@@ -135,7 +139,7 @@ class UiSettings:
gr.Group()
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__()
- current_row = gr.Column(variant='compact')
+ current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact')
current_row.__enter__()
previous_section = item.section
@@ -175,11 +179,18 @@ class UiSettings:
with gr.Row():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+ with gr.Row():
+ calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
+ calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
- gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+ self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+ self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False)
+ self.show_one_page.click(lambda: None)
+
+ self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False)
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
@@ -241,6 +252,21 @@ class UiSettings:
outputs=[sysinfo_check_output],
)
+ def calculate_all_checkpoint_hash_fn(max_thread):
+ checkpoints_list = sd_models.checkpoints_list.values()
+ with ThreadPoolExecutor(max_workers=max_thread) as executor:
+ futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
+ completed = 0
+ for _ in as_completed(futures):
+ completed += 1
+ print(f"{completed} / {len(checkpoints_list)} ")
+ print("Finish calculating hash for all checkpoints")
+
+ calculate_all_checkpoint_hash.click(
+ fn=calculate_all_checkpoint_hash_fn,
+ inputs=[calculate_all_checkpoint_hash_threads],
+ )
+
self.interface = settings_interface
def add_quicksettings(self):
@@ -294,3 +320,8 @@ class UiSettings:
outputs=[self.component_dict[k] for k in component_keys],
queue=False,
)
+
+ def search(self, text):
+ print(text)
+
+ return [gr.update(visible=text in (comp.label or "")) for comp in self.components]
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
new file mode 100644
index 00000000..a727e865
--- /dev/null
+++ b/modules/xlmr_m18.py
@@ -0,0 +1,164 @@
+from transformers import BertPreTrainedModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob= 0.1
+ config.bos_token_id=0
+ config.eos_token_id=2
+ config.hidden_act='gelu'
+ config.hidden_dropout_prob=0.1
+ config.hidden_size=1024
+ config.initializer_range=0.02
+ config.intermediate_size=4096
+ config.layer_norm_eps=1e-05
+ config.max_position_embeddings=514
+
+ config.num_attention_heads=16
+ config.num_hidden_layers=24
+ config.output_past=True
+ config.pad_token_id=1
+ config.position_embedding_type= "absolute"
+
+ config.type_vocab_size= 1
+ config.use_cache=True
+ config.vocab_size= 250002
+ config.project_dim = 1024
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+ # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ # self.pooler = lambda x: x[:,0]
+ # self.post_init()
+
+ self.has_pre_transformation = True
+ if self.has_pre_transformation:
+ self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+ self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_init()
+
+ def encode(self,c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+ text["attention_mask"] = torch.tensor(
+ text['attention_mask']).to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) :
+ r"""
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ # # last module outputs
+ # sequence_output = outputs[0]
+
+
+ # # project every module
+ # sequence_output_ln = self.pre_LN(sequence_output)
+
+ # # pooler
+ # pooler_output = self.pooler(sequence_output_ln)
+ # pooler_output = self.transformation(pooler_output)
+ # projection_state = self.transformation(outputs.last_hidden_state)
+
+ if self.has_pre_transformation:
+ sequence_output2 = outputs["hidden_states"][-2]
+ sequence_output2 = self.pre_LN(sequence_output2)
+ projection_state2 = self.transformation_pre(sequence_output2)
+
+ return {
+ "projection_state": projection_state2,
+ "last_hidden_state": outputs.last_hidden_state,
+ "hidden_states": outputs.hidden_states,
+ "attentions": outputs.attentions,
+ }
+ else:
+ projection_state = self.transformation(outputs.last_hidden_state)
+ return {
+ "projection_state": projection_state,
+ "last_hidden_state": outputs.last_hidden_state,
+ "hidden_states": outputs.hidden_states,
+ "attentions": outputs.attentions,
+ }
+
+
+ # return {
+ # 'pooler_output':pooler_output,
+ # 'last_hidden_state':outputs.last_hidden_state,
+ # 'hidden_states':outputs.hidden_states,
+ # 'attentions':outputs.attentions,
+ # 'projection_state':projection_state,
+ # 'sequence_out': sequence_output
+ # }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
+ config_class= RobertaSeriesConfig
diff --git a/requirements_versions.txt b/requirements_versions.txt
index f8ae1f38..7d27f2be 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -27,5 +27,5 @@ timm==0.9.2
tomesd==0.1.3
torch
torchdiffeq==0.2.3
-torchsde==0.2.5
+torchsde==0.2.6
transformers==4.30.2
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index edb70ac0..eb42a29e 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -29,7 +29,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
+ upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow():
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 50320d55..ca73b2a5 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -5,11 +5,17 @@ import shlex
import modules.scripts as scripts
import gradio as gr
-from modules import sd_samplers, errors
+from modules import sd_samplers, errors, sd_models
from modules.processing import Processed, process_images
from modules.shared import state
+def process_model_tag(tag):
+ info = sd_models.get_closet_checkpoint_match(tag)
+ assert info is not None, f'Unknown checkpoint: {tag}'
+ return info.name
+
+
def process_string_tag(tag):
return tag
@@ -27,7 +33,7 @@ def process_boolean_tag(tag):
prompt_tags = {
- "sd_model": None,
+ "sd_model": process_model_tag,
"outpath_samples": process_string_tag,
"outpath_grids": process_string_tag,
"prompt_for_display": process_string_tag,
@@ -156,7 +162,10 @@ class Script(scripts.Script):
copy_p = copy.copy(p)
for k, v in args.items():
- setattr(copy_p, k, v)
+ if k == "sd_model":
+ copy_p.override_settings['sd_model_checkpoint'] = v
+ else:
+ setattr(copy_p, k, v)
proc = process_images(copy_p)
images += proc.images
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 939d8605..0dc255bc 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -205,13 +205,14 @@ def csv_string_to_list_strip(data_str):
class AxisOption:
- def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
+ def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
self.label = label
self.type = type
self.apply = apply
self.format_value = format_value
self.confirm = confirm
self.cost = cost
+ self.prepare = prepare
self.choices = choices
@@ -536,6 +537,8 @@ class Script(scripts.Script):
if opt.choices is not None and not csv_mode:
valslist = vals_dropdown
+ elif opt.prepare is not None:
+ valslist = opt.prepare(vals)
else:
valslist = csv_string_to_list_strip(vals)
@@ -773,6 +776,8 @@ class Script(scripts.Script):
# TODO: See previous comment about intentional data misalignment.
adj_g = g-1 if g > 0 else g
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
+ if not include_sub_grids: # if not include_sub_grids then skip saving after the first grid
+ break
if not include_sub_grids:
# Done with sub-grids, drop all related information:
diff --git a/style.css b/style.css
index fb4e2f1f..115626cd 100644
--- a/style.css
+++ b/style.css
@@ -83,8 +83,10 @@ div.compact{
white-space: nowrap;
}
-.gradio-dropdown ul.options li.item {
- padding: 0.05em 0;
+@media (pointer:fine) {
+ .gradio-dropdown ul.options li.item {
+ padding: 0.05em 0;
+ }
}
.gradio-dropdown ul.options li.item.selected {
@@ -421,6 +423,7 @@ div#extras_scale_to_tab div.form{
#settings > div{
border: none;
margin-left: 10em;
+ padding: 0 var(--spacing-xl);
}
#settings > div.tab-nav{
@@ -435,6 +438,7 @@ div#extras_scale_to_tab div.form{
border: none;
text-align: left;
white-space: initial;
+ padding: 4px;
}
#settings_result{
@@ -581,7 +585,6 @@ table.popup-table .link{
width: 100%;
height: 100%;
overflow: auto;
- background-color: rgba(20, 20, 20, 0.95);
}
.global-popup *{
@@ -590,9 +593,6 @@ table.popup-table .link{
.global-popup-close:before {
content: "×";
-}
-
-.global-popup-close{
position: fixed;
right: 0.25em;
top: 0;
@@ -601,10 +601,20 @@ table.popup-table .link{
font-size: 32pt;
}
+.global-popup-close{
+ position: fixed;
+ left: 0;
+ top: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(20, 20, 20, 0.95);
+}
+
.global-popup-inner{
display: inline-block;
margin: auto;
padding: 2em;
+ z-index: 1001;
}
/* fullpage image viewer */
diff --git a/webui.bat b/webui.bat
index 42e7d517..a630ea4d 100644
--- a/webui.bat
+++ b/webui.bat
@@ -1,6 +1,7 @@
@echo off
if not defined PYTHON (set PYTHON=python)
+if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
set SD_WEBUI_RESTART=tmp/restart
diff --git a/webui.sh b/webui.sh
index 3d0f87ee..08911469 100755
--- a/webui.sh
+++ b/webui.sh
@@ -4,12 +4,6 @@
# change the variables in webui-user.sh instead #
#################################################
-
-use_venv=1
-if [[ $venv_dir == "-" ]]; then
- use_venv=0
-fi
-
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
@@ -28,6 +22,12 @@ then
source "$SCRIPT_DIR"/webui-user.sh
fi
+# If $venv_dir is "-", then disable venv support
+use_venv=1
+if [[ $venv_dir == "-" ]]; then
+ use_venv=0
+fi
+
# Set defaults
# Install directory without trailing slash
if [[ -z "${install_dir}" ]]
@@ -51,6 +51,8 @@ fi
if [[ -z "${GIT}" ]]
then
export GIT="git"
+else
+ export GIT_PYTHON_GIT_EXECUTABLE="${GIT}"
fi
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
@@ -141,9 +143,8 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
- export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
- # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5)
- # so switch to nightly rocm5.6 without explicit versions this time
+ export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
+ # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"