-rw-r--r--  .eslintrc.js | 4
-rw-r--r--  CHANGELOG.md | 96
-rw-r--r--  CITATION.cff | 7
-rw-r--r--  README.md | 7
-rw-r--r--  extensions-builtin/Lora/lora_patches.py | 31
-rw-r--r--  extensions-builtin/Lora/network_full.py | 7
-rw-r--r--  extensions-builtin/Lora/networks.py | 37
-rw-r--r--  extensions-builtin/Lora/scripts/lora_script.py | 51
-rw-r--r--  extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 128
-rw-r--r--  extensions-builtin/canvas-zoom-and-pan/style.css | 3
-rw-r--r--  extensions-builtin/extra-options-section/scripts/extra_options_section.py | 12
-rw-r--r--  extensions-builtin/mobile/javascript/mobile.js | 6
-rw-r--r--  javascript/extraNetworks.js | 2
-rw-r--r--  javascript/localization.js | 33
-rw-r--r--  javascript/progressbar.js | 67
-rw-r--r--  javascript/resizeHandle.js | 139
-rw-r--r--  javascript/ui.js | 21
-rw-r--r--  modules/api/api.py | 49
-rw-r--r--  modules/cache.py | 5
-rw-r--r--  modules/call_queue.py | 5
-rw-r--r--  modules/cmd_args.py | 5
-rw-r--r--  modules/config_states.py | 16
-rw-r--r--  modules/fifo_lock.py | 37
-rw-r--r--  modules/gradio_extensons.py | 25
-rw-r--r--  modules/images.py | 20
-rw-r--r--  modules/img2img.py | 8
-rw-r--r--  modules/initialize_util.py | 19
-rw-r--r--  modules/interrogate.py | 5
-rw-r--r--  modules/launch_utils.py | 4
-rw-r--r--  modules/lowvram.py | 18
-rw-r--r--  modules/options.py | 19
-rw-r--r--  modules/patches.py | 64
-rw-r--r--  modules/processing.py | 57
-rw-r--r--  modules/processing_scripts/refiner.py | 4
-rw-r--r--  modules/processing_scripts/seed.py | 2
-rw-r--r--  modules/progress.py | 53
-rw-r--r--  modules/prompt_parser.py | 2
-rw-r--r--  modules/realesrgan_model.py | 1
-rw-r--r--  modules/rng.py | 2
-rw-r--r--  modules/script_callbacks.py | 26
-rw-r--r--  modules/scripts.py | 16
-rw-r--r--  modules/sd_disable_initialization.py | 63
-rw-r--r--  modules/sd_hijack.py | 16
-rw-r--r--  modules/sd_models.py | 48
-rw-r--r--  modules/sd_models_types.py | 31
-rw-r--r--  modules/sd_samplers_cfg_denoiser.py | 4
-rw-r--r--  modules/sd_samplers_common.py | 18
-rw-r--r--  modules/sd_samplers_kdiffusion.py | 14
-rw-r--r--  modules/sd_samplers_timesteps.py | 9
-rw-r--r--  modules/sd_unet.py | 2
-rw-r--r--  modules/sd_vae.py | 13
-rw-r--r--  modules/shared.py | 9
-rw-r--r--  modules/shared_gradio_themes.py | 3
-rw-r--r--  modules/shared_options.py | 17
-rw-r--r--  modules/shared_state.py | 2
-rw-r--r--  modules/ui.py | 12
-rw-r--r--  modules/ui_common.py | 2
-rw-r--r--  modules/ui_components.py | 12
-rw-r--r--  modules/ui_extensions.py | 226
-rw-r--r--  modules/ui_extra_networks_checkpoints.py | 3
-rw-r--r--  modules/ui_tempdir.py | 2
-rw-r--r--  scripts/xyz_grid.py | 39
-rw-r--r--  style.css | 85
-rwxr-xr-x  webui.sh | 7
64 files changed, 1319 insertions(+), 431 deletions(-)
diff --git a/.eslintrc.js b/.eslintrc.js
index e3b4fb76..4777c276 100644
--- a/.eslintrc.js
+++ b/.eslintrc.js
@@ -90,6 +90,8 @@ module.exports = {
// localStorage.js
localSet: "readonly",
localGet: "readonly",
- localRemove: "readonly"
+ localRemove: "readonly",
+ // resizeHandle.js
+ setupResizeHandle: "writable"
}
};
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b18c6867..ea1c8b16 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,99 @@
+## 1.6.0
+
+### Features:
+ * refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371)
+ * add NV option for the Random number generator source setting, which makes it possible to generate the same pictures on CPU/AMD/Mac as on NVidia video cards
+ * add style editor dialog
+ * hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181))
+ * option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227))
+ * new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
+ * rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
+ * makes all of them work with img2img
+ * makes prompt composition possible (AND)
+ * makes them available for SDXL
+ * always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
+ * use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958))
+ * textual inversion inference support for SDXL
+ * extra networks UI: show metadata for SD checkpoints
+ * checkpoint merger: add metadata support
+ * prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177))
+ * VAE: allow selecting own VAE for each checkpoint (in user metadata editor)
+ * VAE: add selected VAE to infotext
+ * options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551))
+
+
+### Minor:
+ * img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515))
+ * postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479))
+ * XYZ: in the axis labels, remove pathnames from model filenames
+ * XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298))
+ * XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491))
+ * add gradio version warning
+ * sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297))
+ * use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326))
+ * move some settings to their own section: img2img, VAE
+ * add checkbox to show/hide dirs for extra networks
+ * Add TAESD (or more) options for all VAE encode/decode operations ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311))
+ * gradio theme cache, new gradio themes, along with an explanation that users can input their own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355))
+ * sampler fixes/tweaks: s_tmax, s_churn, s_noise ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521))
+ * update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352))
+ * option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338))
+ * enable cond cache by default
+ * git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230))
+ * allow opening images in a new browser tab with the middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379))
+ * automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254))
+ * put commonly used samplers on top, make DPM++ 2M Karras the default choice
+ * zoom and pan: option to auto-expand a wide image ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413))
+ * option to cache Lora networks in memory
+ * rework hires fix UI to use accordion
+ * face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back
+ * change quicksettings items to have variable width
+ * Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503))
+ * Lora: output warnings in UI rather than failing for Loras that don't fit the model; switch to logging for error output in console
+ * support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510))
+ * Add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564))
+
+### Extensions and API:
+ * gradio 3.39
+ * also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd
+ * support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')
+ * properly clear the total console progressbar when using txt2img and img2img from API
+ * add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294))
+ * shared.py and webui.py split into many files
+ * add --loglevel commandline argument for logging
+ * add a custom UI element that combines accordion and checkbox
+ * avoid importing gradio in tests because it spams warnings
+ * put infotext label for setting into OptionInfo definition rather than in a separate list
+ * make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470))
+ * option to make scripts UI without gr.Group
+ * add a way for scripts to register a callback for before/after just a single component's creation
+ * use dataclass for StableDiffusionProcessing
+
+### Bug Fixes:
+ * Don't crash when running out of local storage quota for JavaScript localStorage
+ * fix memory leak when generation fails
+ * XYZ plot: do not fail if an exception occurs
+ * update doggettx cross attention optimization to not use an unreasonable amount of memory in some edge cases -- suggestion by MorkTheOrk
+ * fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269))
+ * localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307))
+ * fix invalid SDXL model configuration after the hijack
+ * correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304))
+ * open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318))
+ * prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319))
+ * add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327))
+ * fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387))
+ * fix options in main UI misbehaving when there's just one element
+ * make it possible to use a sampler from infotext even if it's hidden in the dropdown
+ * fix styles missing from the prompt in infotext when making a grid from a batch of multiple images
+ * prevent bogus progress output in console when calculating hires fix dimensions
+ * fix --use-textbox-seed
+ * fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466))
+ * properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463))
+ * MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526))
+ * add second_order to samplers that mistakenly didn't have it
+ * when refreshing cards in extra networks UI, do not discard user's custom resolution
+ * fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509))
+
## 1.5.1
### Minor:
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 00000000..2c781aff
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,7 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+authors:
+ - given-names: AUTOMATIC1111
+title: "Stable Diffusion Web UI"
+date-released: 2022-08-22
+url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui"
diff --git a/README.md b/README.md
index 940176d0..4e083440 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Clip skip
- Hypernetworks
- Loras (same as Hypernetworks but more pretty)
-- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
+- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API
@@ -93,7 +93,10 @@ A browser interface based on Gradio library for Stable Diffusion.
- Reorder elements in the UI from settings screen
## Installation and Running
-Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
+Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
+- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
+- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
+- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
Alternatively, use online services (like Google Colab):
diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py
new file mode 100644
index 00000000..b394d8e9
--- /dev/null
+++ b/extensions-builtin/Lora/lora_patches.py
@@ -0,0 +1,31 @@
+import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+ def __init__(self):
+ self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+ self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+ self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+ self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+ self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+ self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+ self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+ self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+ self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+ self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+ def undo(self):
+ self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+ self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+ self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+ self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+ self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+ self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+ self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+ self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+ self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+ self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+
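LoraPatches applies and reverts every override through the new modules/patches.py helper (listed in the diffstat above, diff not shown in this excerpt), which hands back the original attribute so the network_* wrappers can delegate to it. A minimal sketch of the assumed contract, not the actual implementation:

    # hypothetical sketch of the patches.patch/patches.undo contract assumed above
    _originals = {}  # (caller key, object, field) -> original attribute

    def patch(key, obj, field, replacement):
        """Swap obj.field for replacement; return the original so callers can delegate to it."""
        original = getattr(obj, field)
        _originals[(key, obj, field)] = original
        setattr(obj, field, replacement)
        return original

    def undo(key, obj, field):
        """Restore the attribute recorded by patch(); returns None so callers can clear their reference."""
        setattr(obj, field, _originals.pop((key, obj, field)))
        return None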
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index 109b4c2c..bf6930e9 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule):
super().__init__(net, weights)
self.weight = weights.w.get("diff")
+ self.ex_bias = weights.w.get("diff_b")
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ if self.ex_bias is not None:
+ ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ else:
+ ex_bias = None
- return self.finalize_updown(updown, orig_weight, output_shape)
+ return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 22fdff4a..96f935b2 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -2,6 +2,7 @@ import logging
import os
import re
+import lora_patches
import network
import network_lora
import network_hada
@@ -303,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
weights_backup = getattr(self, "network_weights_backup", None)
- if weights_backup is None:
+ if weights_backup is None and wanted_names != ():
+ if current_names != ():
+ raise RuntimeError("no backup weights found and current weights are not unchanged")
+
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
@@ -418,74 +422,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
def network_Linear_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Linear_forward_before_network)
+ return network_forward(self, input, originals.Linear_forward)
network_apply_weights(self)
- return torch.nn.Linear_forward_before_network(self, input)
+ return originals.Linear_forward(self, input)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Linear_load_state_dict(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+ return network_forward(self, input, originals.Conv2d_forward)
network_apply_weights(self)
- return torch.nn.Conv2d_forward_before_network(self, input)
+ return originals.Conv2d_forward(self, input)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Conv2d_load_state_dict(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+ return network_forward(self, input, originals.GroupNorm_forward)
network_apply_weights(self)
- return torch.nn.GroupNorm_forward_before_network(self, input)
+ return originals.GroupNorm_forward(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+ return network_forward(self, input, originals.LayerNorm_forward)
network_apply_weights(self)
- return torch.nn.LayerNorm_forward_before_network(self, input)
+ return originals.LayerNorm_forward(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
- return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
def list_available_networks():
@@ -552,6 +556,9 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
+
+originals: lora_patches.LoraPatches = None
+
extra_network_lora = None
available_networks = {}
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 4c6e774a..ef23968c 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,23 +1,19 @@
import re
-import torch
import gradio as gr
from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
+import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+
def unload():
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
- torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
- torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
- torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
- torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+ networks.originals.undo()
def before_ui():
@@ -28,46 +24,7 @@ def before_ui():
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
- torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
- torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
- torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
- torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
- torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
- torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
- torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
- torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
- torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
- torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
index 72c8ba87..23423891 100644
--- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
+++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
@@ -15,6 +15,19 @@ onUiLoaded(async() => {
// Helper functions
// Get active tab
+
+ /**
+ * Waits for an element to be present in the DOM.
+ */
+ const waitForElement = (id) => new Promise(resolve => {
+ const checkForElement = () => {
+ const element = document.querySelector(id);
+ if (element) return resolve(element);
+ setTimeout(checkForElement, 100);
+ };
+ checkForElement();
+ });
+
function getActiveTab(elements, all = false) {
const tabs = elements.img2imgTabs.querySelectorAll("button");
@@ -35,7 +48,7 @@ onUiLoaded(async() => {
// Wait until opts loaded
async function waitForOpts() {
- for (;;) {
+ for (; ;) {
if (window.opts && Object.keys(window.opts).length) {
return window.opts;
}
@@ -256,7 +269,7 @@ onUiLoaded(async() => {
input?.addEventListener("input", () => restoreImgRedMask(elements));
}
- function applyZoomAndPan(elemId) {
+ function applyZoomAndPan(elemId, isExtension = true) {
const targetElement = gradioApp().querySelector(elemId);
if (!targetElement) {
@@ -368,6 +381,10 @@ onUiLoaded(async() => {
panY: 0
};
+ if (isExtension) {
+ targetElement.style.overflow = "hidden";
+ }
+
fixCanvas();
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
@@ -383,8 +400,22 @@ onUiLoaded(async() => {
closeBtn.addEventListener("click", resetZoom);
}
+ if (canvas && isExtension) {
+ const parentElement = targetElement.closest('[id^="component-"]');
+ if (
+ canvas &&
+ parseFloat(canvas.style.width) > parentElement.offsetWidth &&
+ parseFloat(targetElement.style.width) > parentElement.offsetWidth
+ ) {
+ fitToElement();
+ return;
+ }
+
+ }
+
if (
canvas &&
+ !isExtension &&
parseFloat(canvas.style.width) > 865 &&
parseFloat(targetElement.style.width) > 865
) {
@@ -393,9 +424,6 @@ onUiLoaded(async() => {
}
targetElement.style.width = "";
- if (canvas) {
- targetElement.style.height = canvas.style.height;
- }
}
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
@@ -462,6 +490,10 @@ onUiLoaded(async() => {
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
toggleOverlap("on");
+ if (isExtension) {
+ targetElement.style.overflow = "visible";
+ }
+
return newZoomLevel;
}
@@ -484,7 +516,7 @@ onUiLoaded(async() => {
fullScreenMode = false;
elemData[elemId].zoomLevel = updateZoom(
elemData[elemId].zoomLevel +
- (operation === "+" ? delta : -delta),
+ (operation === "+" ? delta : -delta),
zoomPosX - targetElement.getBoundingClientRect().left,
zoomPosY - targetElement.getBoundingClientRect().top
);
@@ -501,10 +533,19 @@ onUiLoaded(async() => {
//Reset Zoom
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
+ let parentElement;
+
+ if (isExtension) {
+ parentElement = targetElement.closest('[id^="component-"]');
+ } else {
+ parentElement = targetElement.parentElement;
+ }
+
+
// Get element and screen dimensions
const elementWidth = targetElement.offsetWidth;
const elementHeight = targetElement.offsetHeight;
- const parentElement = targetElement.parentElement;
+
const screenWidth = parentElement.clientWidth;
const screenHeight = parentElement.clientHeight;
@@ -555,10 +596,15 @@ onUiLoaded(async() => {
`${elemId} canvas[key="interface"]`
);
+ if (isExtension) {
+ targetElement.style.overflow = "visible";
+ }
+
+
if (!canvas) return;
- if (canvas.offsetWidth > 862) {
- targetElement.style.width = canvas.offsetWidth + "px";
+ if (canvas.offsetWidth > 862 || isExtension) {
+ targetElement.style.width = (canvas.offsetWidth + 2) + "px";
}
if (fullScreenMode) {
@@ -667,9 +713,7 @@ onUiLoaded(async() => {
targetElement.isExpanded = false;
function autoExpand() {
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
- const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;
-
- if (canvas && isMainTab) {
+ if (canvas) {
if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
targetElement.style.visibility = "hidden";
setTimeout(() => {
@@ -808,6 +852,11 @@ onUiLoaded(async() => {
if (isMoving && elemId === activeElement) {
updatePanPosition(e.movementX, e.movementY);
targetElement.style.pointerEvents = "none";
+
+ if (isExtension) {
+ targetElement.style.overflow = "visible";
+ }
+
} else {
targetElement.style.pointerEvents = "auto";
}
@@ -821,10 +870,57 @@ onUiLoaded(async() => {
gradioApp().addEventListener("mousemove", handleMoveByKey);
}
- applyZoomAndPan(elementIDs.sketch);
- applyZoomAndPan(elementIDs.inpaint);
- applyZoomAndPan(elementIDs.inpaintSketch);
+ applyZoomAndPan(elementIDs.sketch, false);
+ applyZoomAndPan(elementIDs.inpaint, false);
+ applyZoomAndPan(elementIDs.inpaintSketch, false);
// Make the function global so that other extensions can take advantage of this solution
- window.applyZoomAndPan = applyZoomAndPan;
+ const applyZoomAndPanIntegration = async(id, elementIDs) => {
+ const mainEl = document.querySelector(id);
+ if (id.toLocaleLowerCase() === "none") {
+ for (const elementID of elementIDs) {
+ const el = await waitForElement(elementID);
+ if (!el) break;
+ applyZoomAndPan(elementID);
+ }
+ return;
+ }
+
+ if (!mainEl) return;
+ mainEl.addEventListener("click", async() => {
+ for (const elementID of elementIDs) {
+ const el = await waitForElement(elementID);
+ if (!el) break;
+ applyZoomAndPan(elementID);
+ }
+ }, {once: true});
+ };
+
+ window.applyZoomAndPan = applyZoomAndPan; // Takes a single element ID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image")
+
+ window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension
+
+ /*
+ The function `applyZoomAndPanIntegration` takes two arguments:
+
+ 1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click.
+ If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event.
+
+ 2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument.
+ If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event.
+
+ Example usage:
+ applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
+ In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet".
+ */
+
+ // More examples
+ // Add integration with ControlNet txt2img One TAB
+ // applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
+
+ // Add integration with ControlNet txt2img Tabs
+ // applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));
+
+ // Add integration with Inpaint Anything
+ // applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]);
});
diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css
index 6bcc9570..5d8054e6 100644
--- a/extensions-builtin/canvas-zoom-and-pan/style.css
+++ b/extensions-builtin/canvas-zoom-and-pan/style.css
@@ -61,3 +61,6 @@
to {opacity: 1;}
}
+.styler {
+ overflow:inherit !important;
+} \ No newline at end of file
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
index 588b64d2..983f87ff 100644
--- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py
+++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
@@ -22,22 +22,23 @@ class ExtraOptionsSection(scripts.Script):
self.comps = []
self.setting_names = []
self.infotext_fields = []
+ extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
- with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():
+ with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
- row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)
+ row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
for row in range(row_count):
with gr.Row():
for col in range(shared.opts.extra_options_cols):
index = row * shared.opts.extra_options_cols + col
- if index >= len(shared.opts.extra_options):
+ if index >= len(extra_options):
break
- setting_name = shared.opts.extra_options[index]
+ setting_name = extra_options[index]
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
@@ -64,7 +65,8 @@ class ExtraOptionsSection(scripts.Script):
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
- "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
+ "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+ "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
}))
diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js
index 12cae4b7..652f07ac 100644
--- a/extensions-builtin/mobile/javascript/mobile.js
+++ b/extensions-builtin/mobile/javascript/mobile.js
@@ -20,7 +20,13 @@ function reportWindowSize() {
var button = gradioApp().getElementById(tab + '_generate_box');
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
target.insertBefore(button, target.firstElementChild);
+
+ gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);
}
}
window.addEventListener("resize", reportWindowSize);
+
+onUiLoaded(function() {
+ reportWindowSize();
+});
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 897ebeba..3bc723d3 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -332,7 +332,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
newDiv.innerHTML = data.html;
var newCard = newDiv.firstElementChild;
- newCard.style = '';
+ newCard.style.display = '';
card.parentElement.insertBefore(newCard, card);
card.parentElement.removeChild(card);
}
diff --git a/javascript/localization.js b/javascript/localization.js
index 0c9032f9..8f00c186 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -107,12 +107,41 @@ function processNode(node) {
});
}
+function localizeWholePage() {
+ processNode(gradioApp());
+
+ function elem(comp) {
+ var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id;
+ return gradioApp().getElementById(elem_id);
+ }
+
+ for (var comp of window.gradio_config.components) {
+ if (comp.props.webui_tooltip) {
+ let e = elem(comp);
+
+ let tl = e ? getTranslation(e.title) : undefined;
+ if (tl !== undefined) {
+ e.title = tl;
+ }
+ }
+ if (comp.props.placeholder) {
+ let e = elem(comp);
+ let textbox = e ? e.querySelector('[placeholder]') : null;
+
+ let tl = textbox ? getTranslation(textbox.placeholder) : undefined;
+ if (tl !== undefined) {
+ textbox.placeholder = tl;
+ }
+ }
+ }
+}
+
function dumpTranslations() {
if (!hasLocalization()) {
// If we don't have any localization,
// we will not have traversed the app to find
// original_lines, so do that now.
- processNode(gradioApp());
+ localizeWholePage();
}
var dumped = {};
if (localization.rtl) {
@@ -154,7 +183,7 @@ document.addEventListener("DOMContentLoaded", function() {
});
});
- processNode(gradioApp());
+ localizeWholePage();
if (localization.rtl) { // if the language is from right to left,
(new MutationObserver((mutations, observer) => { // wait for the style to load
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 29299787..77761495 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var dateStart = new Date();
var wasEverActive = false;
var parentProgressbar = progressbarContainer.parentNode;
- var parentGallery = gallery ? gallery.parentNode : null;
var divProgress = document.createElement('div');
divProgress.className = 'progressDiv';
@@ -80,32 +79,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.appendChild(divInner);
parentProgressbar.insertBefore(divProgress, progressbarContainer);
- if (parentGallery) {
- var livePreview = document.createElement('div');
- livePreview.className = 'livePreview';
- parentGallery.insertBefore(livePreview, gallery);
- }
+ var livePreview = null;
var removeProgressBar = function() {
+ if (!divProgress) return;
+
setTitle("");
parentProgressbar.removeChild(divProgress);
- if (parentGallery) parentGallery.removeChild(livePreview);
+ if (gallery && livePreview) gallery.removeChild(livePreview);
atEnd();
+
+ divProgress = null;
};
- var fun = function(id_task, id_live_preview) {
- request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
+ var funProgress = function(id_task) {
+ request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) {
if (res.completed) {
removeProgressBar();
return;
}
- var rect = progressbarContainer.getBoundingClientRect();
-
- if (rect.width) {
- divProgress.style.width = rect.width + "px";
- }
-
let progressText = "";
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
@@ -119,7 +112,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
progressText += " ETA: " + formatTime(res.eta);
}
-
setTitle(progressText);
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
@@ -142,16 +134,33 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
return;
}
+ if (onProgress) {
+ onProgress(res);
+ }
- if (res.live_preview && gallery) {
- rect = gallery.getBoundingClientRect();
- if (rect.width) {
- livePreview.style.width = rect.width + "px";
- livePreview.style.height = rect.height + "px";
- }
+ setTimeout(() => {
+ funProgress(id_task, res.id_live_preview);
+ }, opts.live_preview_refresh_period || 500);
+ }, function() {
+ removeProgressBar();
+ });
+ };
+ var funLivePreview = function(id_task, id_live_preview) {
+ request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
+ if (!divProgress) {
+ return;
+ }
+
+ if (res.live_preview && gallery) {
var img = new Image();
img.onload = function() {
+ if (!livePreview) {
+ livePreview = document.createElement('div');
+ livePreview.className = 'livePreview';
+ gallery.insertBefore(livePreview, gallery.firstElementChild);
+ }
+
livePreview.appendChild(img);
if (livePreview.childElementCount > 2) {
livePreview.removeChild(livePreview.firstElementChild);
@@ -160,18 +169,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
img.src = res.live_preview;
}
-
- if (onProgress) {
- onProgress(res);
- }
-
setTimeout(() => {
- fun(id_task, res.id_live_preview);
+ funLivePreview(id_task, res.id_live_preview);
}, opts.live_preview_refresh_period || 500);
}, function() {
removeProgressBar();
});
};
- fun(id_task, 0);
+ funProgress(id_task, 0);
+
+ if (gallery) {
+ funLivePreview(id_task, 0);
+ }
+
}
diff --git a/javascript/resizeHandle.js b/javascript/resizeHandle.js
new file mode 100644
index 00000000..2fd3c4d2
--- /dev/null
+++ b/javascript/resizeHandle.js
@@ -0,0 +1,139 @@
+(function() {
+ const GRADIO_MIN_WIDTH = 320;
+ const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
+ const PAD = 16;
+ const DEBOUNCE_TIME = 100;
+
+ const R = {
+ tracking: false,
+ parent: null,
+ parentWidth: null,
+ leftCol: null,
+ leftColStartWidth: null,
+ screenX: null,
+ };
+
+ let resizeTimer;
+ let parents = [];
+
+ function setLeftColGridTemplate(el, width) {
+ el.style.gridTemplateColumns = `${width}px 16px 1fr`;
+ }
+
+ function displayResizeHandle(parent) {
+ if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
+ parent.style.display = 'flex';
+ if (R.handle != null) {
+ R.handle.style.opacity = '0';
+ }
+ return false;
+ } else {
+ parent.style.display = 'grid';
+ if (R.handle != null) {
+ R.handle.style.opacity = '100';
+ }
+ return true;
+ }
+ }
+
+ function afterResize(parent) {
+ if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
+ const oldParentWidth = R.parentWidth;
+ const newParentWidth = parent.offsetWidth;
+ const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
+
+ const ratio = newParentWidth / oldParentWidth;
+
+ const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH);
+ setLeftColGridTemplate(parent, newWidthL);
+
+ R.parentWidth = newParentWidth;
+ }
+ }
+
+ function setup(parent) {
+ const leftCol = parent.firstElementChild;
+ const rightCol = parent.lastElementChild;
+
+ parents.push(parent);
+
+ parent.style.display = 'grid';
+ parent.style.gap = '0';
+ parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
+
+ const resizeHandle = document.createElement('div');
+ resizeHandle.classList.add('resize-handle');
+ parent.insertBefore(resizeHandle, rightCol);
+
+ resizeHandle.addEventListener('mousedown', (evt) => {
+ if (evt.button !== 0) return;
+
+ evt.preventDefault();
+ evt.stopPropagation();
+
+ document.body.classList.add('resizing');
+
+ R.tracking = true;
+ R.parent = parent;
+ R.parentWidth = parent.offsetWidth;
+ R.handle = resizeHandle;
+ R.leftCol = leftCol;
+ R.leftColStartWidth = leftCol.offsetWidth;
+ R.screenX = evt.screenX;
+ });
+
+ resizeHandle.addEventListener('dblclick', (evt) => {
+ evt.preventDefault();
+ evt.stopPropagation();
+
+ parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
+ });
+
+ afterResize(parent);
+ }
+
+ window.addEventListener('mousemove', (evt) => {
+ if (evt.button !== 0) return;
+
+ if (R.tracking) {
+ evt.preventDefault();
+ evt.stopPropagation();
+
+ const delta = R.screenX - evt.screenX;
+ const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
+ setLeftColGridTemplate(R.parent, leftColWidth);
+ }
+ });
+
+ window.addEventListener('mouseup', (evt) => {
+ if (evt.button !== 0) return;
+
+ if (R.tracking) {
+ evt.preventDefault();
+ evt.stopPropagation();
+
+ R.tracking = false;
+
+ document.body.classList.remove('resizing');
+ }
+ });
+
+
+ window.addEventListener('resize', () => {
+ clearTimeout(resizeTimer);
+
+ resizeTimer = setTimeout(function() {
+ for (const parent of parents) {
+ afterResize(parent);
+ }
+ }, DEBOUNCE_TIME);
+ });
+
+ setupResizeHandle = setup;
+})();
+
+onUiLoaded(function() {
+ for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
+ setupResizeHandle(elem);
+ }
+});
diff --git a/javascript/ui.js b/javascript/ui.js
index bade3089..bedcbf3e 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -19,28 +19,11 @@ function all_gallery_buttons() {
}
function selected_gallery_button() {
- var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
- var visibleCurrentButton = null;
- allCurrentButtons.forEach(function(elem) {
- if (elem.parentElement.offsetParent) {
- visibleCurrentButton = elem;
- }
- });
- return visibleCurrentButton;
+ return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null;
}
function selected_gallery_index() {
- var buttons = all_gallery_buttons();
- var button = selected_gallery_button();
-
- var result = -1;
- buttons.forEach(function(v, i) {
- if (v == button) {
- result = i;
- }
- });
-
- return result;
+ return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));
}
function extract_image_from_gallery(gallery) {
diff --git a/modules/api/api.py b/modules/api/api.py
index 908c4514..e6edffe7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -4,6 +4,8 @@ import os
import time
import datetime
import uvicorn
+import ipaddress
+import requests
import gradio as gr
from threading import Lock
from io import BytesIO
@@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
-from modules.sd_vae import vae_dict
+from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -56,7 +57,41 @@ def setUpscalers(req: dict):
return reqDict
+def verify_url(url):
+ """Returns True if the url refers to a global resource."""
+
+ import socket
+ from urllib.parse import urlparse
+ try:
+ parsed_url = urlparse(url)
+ domain_name = parsed_url.netloc
+ host = socket.gethostbyname_ex(domain_name)
+ for ip in host[2]:
+ ip_addr = ipaddress.ip_address(ip)
+ if not ip_addr.is_global:
+ return False
+ except Exception:
+ return False
+
+ return True
+
+
def decode_base64_to_image(encoding):
+ if encoding.startswith("http://") or encoding.startswith("https://"):
+ if not opts.api_enable_requests:
+ raise HTTPException(status_code=500, detail="Requests not allowed")
+
+ if opts.api_forbid_local_requests and not verify_url(encoding):
+ raise HTTPException(status_code=500, detail="Request to local resource not allowed")
+
+ headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
+ response = requests.get(encoding, timeout=30, headers=headers)
+ try:
+ image = Image.open(BytesIO(response.content))
+ return image
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Invalid image url") from e
+
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
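Together with the new api_enable_requests and api_forbid_local_requests options, this lets API clients pass image URLs while blocking server-side request forgery: any hostname that resolves to a loopback or private address is rejected before the fetch. The check in isolation, as a small self-contained sketch (hostnames are illustrative):

    import ipaddress
    import socket

    def is_global_host(hostname):
        # Resolve every address for the host and reject it if any one is
        # private, loopback, or otherwise non-global (mirrors verify_url above).
        try:
            _, _, ips = socket.gethostbyname_ex(hostname)
            return all(ipaddress.ip_address(ip).is_global for ip in ips)
        except OSError:
            return False

    print(is_global_host("localhost"))    # False: resolves to 127.0.0.1
    print(is_global_host("example.com"))  # True for a normally hosted site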
@@ -330,6 +365,7 @@ class Api:
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+ p.is_api = True
p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
@@ -390,6 +426,7 @@ class Api:
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.is_api = True
p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
@@ -533,7 +570,7 @@ class Api:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
- shared.opts.set(k, v)
+ shared.opts.set(k, v, is_api=True)
shared.opts.save(shared.config_filename)
return
@@ -565,10 +602,12 @@ class Api:
]
def get_sd_models(self):
- return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ import modules.sd_models as sd_models
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
def get_sd_vaes(self):
- return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+ import modules.sd_vae as sd_vae
+ return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/cache.py b/modules/cache.py
index a7cd3aeb..ff26a213 100644
--- a/modules/cache.py
+++ b/modules/cache.py
@@ -30,9 +30,12 @@ def dump_cache():
time.sleep(1)
with cache_lock:
- with open(cache_filename, "w", encoding="utf8") as file:
+ cache_filename_tmp = cache_filename + "-"
+ with open(cache_filename_tmp, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
+ os.replace(cache_filename_tmp, cache_filename)
+
dump_cache_after = None
dump_cache_thread = None
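The cache is now dumped via write-to-temp-then-rename, so a crash or concurrent reader mid-dump can never observe a truncated cache file: os.replace swaps the files in a single atomic step on POSIX (and replaces in one operation on Windows). The same pattern as a standalone sketch, with the filename suffix mirroring the code above:

    import json
    import os

    def atomic_json_dump(data, filename):
        tmp = filename + "-"  # same temp-name convention as dump_cache above
        with open(tmp, "w", encoding="utf8") as file:
            json.dump(data, file, indent=4)
        os.replace(tmp, filename)  # atomic swap: readers see old or new file, never a torn write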
diff --git a/modules/call_queue.py b/modules/call_queue.py
index f2eb17d6..ddf0d573 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,11 +1,10 @@
from functools import wraps
import html
-import threading
import time
-from modules import shared, progress, errors, devices
+from modules import shared, progress, errors, devices, fifo_lock
-queue_lock = threading.Lock()
+queue_lock = fifo_lock.FIFOLock()
def wrap_queued_call(func):
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index b0a11538..f0f361bd 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -35,9 +35,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
+parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
-parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
+parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
@@ -81,7 +82,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
-parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
+parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
diff --git a/modules/config_states.py b/modules/config_states.py
index 6f1ab53f..b766aef1 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -8,14 +8,12 @@ import time
import tqdm
from datetime import datetime
-from collections import OrderedDict
import git
from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir
-
-all_config_states = OrderedDict()
+all_config_states = {}
def list_config_states():
@@ -28,10 +26,14 @@ def list_config_states():
for filename in os.listdir(config_states_dir):
if filename.endswith(".json"):
path = os.path.join(config_states_dir, filename)
- with open(path, "r", encoding="utf-8") as f:
- j = json.load(f)
- j["filepath"] = path
- config_states.append(j)
+ try:
+ with open(path, "r", encoding="utf-8") as f:
+ j = json.load(f)
+ assert "created_at" in j, '"created_at" does not exist'
+ j["filepath"] = path
+ config_states.append(j)
+ except Exception as e:
+ print(f'[ERROR]: Config states {path}, {e}')
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
diff --git a/modules/fifo_lock.py b/modules/fifo_lock.py
new file mode 100644
index 00000000..c35b3ae2
--- /dev/null
+++ b/modules/fifo_lock.py
@@ -0,0 +1,37 @@
+import threading
+import collections
+
+
+# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
+class FIFOLock(object):
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._inner_lock = threading.Lock()
+ self._pending_threads = collections.deque()
+
+ def acquire(self, blocking=True):
+ with self._inner_lock:
+ lock_acquired = self._lock.acquire(False)
+ if lock_acquired:
+ return True
+ elif not blocking:
+ return False
+
+ release_event = threading.Event()
+ self._pending_threads.append(release_event)
+
+ release_event.wait()
+ return self._lock.acquire()
+
+ def release(self):
+ with self._inner_lock:
+ if self._pending_threads:
+ release_event = self._pending_threads.popleft()
+ release_event.set()
+
+ self._lock.release()
+
+ __enter__ = acquire
+
+ def __exit__(self, t, v, tb):
+ self.release()
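
A usage sketch (not part of the commit) of the new FIFOLock: it behaves like threading.Lock, but wakes waiting threads in arrival order, so queued requests are served first-come, first-served.

```python
import threading
from modules.fifo_lock import FIFOLock

lock = FIFOLock()

def worker(name):
    with lock:  # __enter__ is acquire(); pending waiters are released FIFO
        print(f"{name} holds the lock")

threads = [threading.Thread(target=worker, args=(f"job-{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```
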
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
index 77c34c8b..e6b6835a 100644
--- a/modules/gradio_extensons.py
+++ b/modules/gradio_extensons.py
@@ -1,6 +1,7 @@
import gradio as gr
-from modules import scripts, ui_tempdir
+from modules import scripts, ui_tempdir, patches
+
def add_classes_to_gradio_component(comp):
"""
@@ -40,6 +41,8 @@ def Block_get_config(self):
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip
+ config.pop('example_inputs', None)
+
return config
@@ -51,12 +54,20 @@ def BlockContext_init(self, *args, **kwargs):
return res
-original_IOComponent_init = gr.components.IOComponent.__init__
-original_Block_get_config = gr.blocks.Block.get_config
-original_BlockContext_init = gr.blocks.BlockContext.__init__
+def Blocks_get_config_file(self, *args, **kwargs):
+ config = original_Blocks_get_config_file(self, *args, **kwargs)
+
+ for comp_config in config["components"]:
+ if "example_inputs" in comp_config:
+ comp_config["example_inputs"] = {"serialized": []}
+
+ return config
+
+
+original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
+original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
+original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
+original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
-gr.components.IOComponent.__init__ = IOComponent_init
-gr.blocks.Block.get_config = Block_get_config
-gr.blocks.BlockContext.__init__ = BlockContext_init
ui_tempdir.install_ui_tempdir_override()
diff --git a/modules/images.py b/modules/images.py
index 019c1d60..eb644733 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -355,7 +355,9 @@ class FilenameGenerator:
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
- 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
+ 'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
+ 'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
+ 'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
'prompt': lambda self: sanitize_filename_part(self.prompt),
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
@@ -368,7 +370,8 @@ class FilenameGenerator:
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
- 'none': lambda self: '', # Overrides the default so you can get just the sequence number
+ 'none': lambda self: '', # Overrides the default, so you can get just the sequence number
+ 'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
}
default_time_format = '%Y%m%d%H%M%S'
@@ -448,6 +451,14 @@ class FilenameGenerator:
return sanitize_filename_part(formatted_time, replace_spaces=False)
+ def image_hash(self, *args):
+ length = int(args[0]) if (args and args[0] != "") else None
+ return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
+
+ def string_hash(self, text, *args):
+ length = int(args[0]) if (args and args[0] != "") else 8
+ return hashlib.sha256(text.encode()).hexdigest()[0:length]
+
def apply(self, x):
res = ''
@@ -589,6 +600,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
namegen = FilenameGenerator(p, seed, prompt, image)
+ # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively; switch to PNG, which has a much higher limit
+ if ((image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg")) or ((image.height > 16383 or image.width > 16383) and extension.lower() == "webp"):
+ print('Image dimensions too large; saving as PNG')
+ extension = ".png"
+
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
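
For reference, a minimal sketch (not from the commit) of what the new filename-pattern hash tags compute; [prompt_hash] keeps its previous 8-character default, and the optional <length> argument is new:

```python
import hashlib

def string_hash(text, length=8):
    # mirrors FilenameGenerator.string_hash: truncated SHA-256 of the text
    return hashlib.sha256(text.encode()).hexdigest()[0:length]

# [prompt_hash]              -> string_hash(prompt)
# [negative_prompt_hash<12>] -> string_hash(negative_prompt, 12)
# [image_hash]               -> full SHA-256 of image.tobytes(); [image_hash<10>] truncates to 10
```
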
diff --git a/modules/img2img.py b/modules/img2img.py
index ac9fd3f8..1519e132 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -122,15 +122,14 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
is_batch = mode == 5
if mode == 0: # img2img
- image = init_img.convert("RGB")
+ image = init_img
mask = None
elif mode == 1: # img2img sketch
- image = sketch.convert("RGB")
+ image = sketch
mask = None
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
- image = image.convert("RGB")
+ mask = processing.create_binary_mask(mask)
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
orig = inpaint_color_sketch_orig or inpaint_color_sketch
@@ -139,7 +138,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
- image = image.convert("RGB")
elif mode == 4: # inpaint upload mask
image = init_img_inpaint
mask = init_mask_inpaint
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index d8370576..2894eee4 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -132,10 +132,29 @@ def get_gradio_auth_creds():
yield cred
+def dumpstacks():
+ import threading
+ import traceback
+
+ id2name = {th.ident: th.name for th in threading.enumerate()}
+ code = []
+ for threadId, stack in sys._current_frames().items():
+ code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
+ for filename, lineno, name, line in traceback.extract_stack(stack):
+ code.append(f"""File: "{filename}", line {lineno}, in {name}""")
+ if line:
+ code.append(" " + line.strip())
+
+ print("\n".join(code))
+
+
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
+
+ dumpstacks()
+
os._exit(0)
if not os.environ.get("COVERAGE_RUN"):
diff --git a/modules/interrogate.py b/modules/interrogate.py
index a3ae1dd5..3045560d 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -186,9 +186,8 @@ class InterrogateModels:
res = ""
shared.state.begin(job="interrogate")
try:
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- devices.torch_gc()
+ lowvram.send_everything_to_cpu()
+ devices.torch_gc()
self.load()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 449a8755..7e4d5a61 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -246,7 +246,7 @@ def list_extensions(settings_file):
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
- if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions:
+ if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
return []
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
@@ -321,7 +321,7 @@ def prepare_environment():
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
- stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 96f52b7b..45701046 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,5 +1,5 @@
import torch
-from modules import devices
+from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")
@@ -14,6 +14,20 @@ def send_everything_to_cpu():
module_in_gpu = None
+def is_needed(sd_model):
+ return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or (shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner'))
+
+
+def apply(sd_model):
+ enable = is_needed(sd_model)
+ shared.parallel_processing_allowed = not enable
+
+ if enable:
+ setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
+ else:
+ sd_model.lowvram = False
+
+
def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
@@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
def is_enabled(sd_model):
- return getattr(sd_model, 'lowvram', False)
+ return sd_model.lowvram
diff --git a/modules/options.py b/modules/options.py
index db1fb157..758b1ce5 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -8,7 +8,7 @@ from modules.shared_cmd_options import cmd_opts
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
self.default = default
self.label = label
self.component = component
@@ -26,6 +26,9 @@ class OptionInfo:
self.infotext = infotext
+ self.restrict_api = restrict_api
+ """If True, the setting will not be accessible via API"""
+
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
@@ -71,7 +74,7 @@ options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
class Options:
typemap = {int: float}
- def __init__(self, data_labels, restricted_opts):
+ def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()}
self.restricted_opts = restricted_opts
@@ -113,14 +116,18 @@ class Options:
return super(Options, self).__getattribute__(item)
- def set(self, key, value):
+ def set(self, key, value, is_api=False, run_callbacks=True):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
- if self.data_labels[key].do_not_save:
+ option = self.data_labels[key]
+ if option.do_not_save:
+ return False
+
+ if is_api and option.restrict_api:
return False
try:
@@ -128,9 +135,9 @@ class Options:
except RuntimeError:
return False
- if self.data_labels[key].onchange is not None:
+ if run_callbacks and option.onchange is not None:
try:
- self.data_labels[key].onchange()
+ option.onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
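
A minimal sketch (not from the commit) of the extended set() contract, using the api_useragent option that is declared with restrict_api=True later in this changeset:

```python
from modules import shared

# API callers are refused for restricted options:
assert shared.opts.set("api_useragent", "my-bot/1.0", is_api=True) is False

# UI callers succeed; run_callbacks=False would additionally skip the option's onchange callback
assert shared.opts.set("api_useragent", "my-bot/1.0") is True
```
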
diff --git a/modules/patches.py b/modules/patches.py
new file mode 100644
index 00000000..348235e7
--- /dev/null
+++ b/modules/patches.py
@@ -0,0 +1,64 @@
+from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+ """Replaces a function in a module or a class.
+
+ Also stores the original function in this module, so it can be retrieved via original(key, obj, field).
+ If the function is already replaced by this caller (key), an exception is raised -- use undo() first.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+ replacement: the new function
+
+ Returns:
+ the original function
+ """
+
+ patch_key = (obj, field)
+ if patch_key in originals[key]:
+ raise RuntimeError(f"patch for {field} is already applied")
+
+ original_func = getattr(obj, field)
+ originals[key][patch_key] = original_func
+
+ setattr(obj, field, replacement)
+
+ return original_func
+
+
+def undo(key, obj, field):
+ """Undoes the peplacement by the patch().
+
+ If the function is not replaced, raises an exception.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+
+ Returns:
+ Always None
+ """
+
+ patch_key = (obj, field)
+
+ if patch_key not in originals[key]:
+ raise RuntimeError(f"there is no patch for {field} to undo")
+
+ original_func = originals[key].pop(patch_key)
+ setattr(obj, field, original_func)
+
+ return None
+
+
+def original(key, obj, field):
+ """Returns the original function for the patch created by the patch() function"""
+ patch_key = (obj, field)
+
+ return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)
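
A hedged sketch (not part of the commit; Greeter is a made-up class) showing the intended patch()/original()/undo() lifecycle:

```python
from modules import patches

class Greeter:
    def hello(self):
        return "hello"

def hello_loudly(self):
    # retrieve the stored original and wrap it
    original = patches.original(__name__, Greeter, "hello")
    return original(self).upper()

patches.patch(__name__, obj=Greeter, field="hello", replacement=hello_loudly)
assert Greeter().hello() == "HELLO"

patches.undo(__name__, Greeter, "hello")  # patching the same field twice without this raises RuntimeError
assert Greeter().hello() == "hello"
```
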
diff --git a/modules/processing.py b/modules/processing.py
index 75f1d66f..066351c1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -81,6 +81,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def create_binary_mask(image):
+ if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.convert('L')
+ return image
def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
@@ -194,6 +200,8 @@ class StableDiffusionProcessing:
sd_vae_name: str = field(default=None, init=False)
sd_vae_hash: str = field(default=None, init=False)
+ is_api: bool = field(default=False, init=False)
+
def __post_init__(self):
if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -258,7 +266,7 @@ class StableDiffusionProcessing:
def setup_scripts(self):
self.scripts_setup_complete = True
- self.scripts.setup_scrips(self)
+ self.scripts.setup_scrips(self, is_ui=not self.is_api)
def comment(self, text):
self.comments[text] = 1
@@ -378,15 +386,20 @@ class StableDiffusionProcessing:
return self.token_merging_ratio or opts.token_merging_ratio
def setup_prompts(self):
- if type(self.prompt) == list:
+ if isinstance(self.prompt, list):
self.all_prompts = self.prompt
+ elif isinstance(self.negative_prompt, list):
+ self.all_prompts = [self.prompt] * len(self.negative_prompt)
else:
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
- if type(self.negative_prompt) == list:
+ if isinstance(self.negative_prompt, list):
self.all_negative_prompts = self.negative_prompt
else:
- self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+ self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
+
+ if len(self.all_prompts) != len(self.all_negative_prompts):
+ raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
@@ -503,10 +516,10 @@ class Processed:
self.s_noise = p.s_noise
self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
- self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
- self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]
+ self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]
+ self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1
+ self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
@@ -693,17 +706,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
- # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
- if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
- sd_models.reload_model_weights()
-
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
+ # and if, after running the refiner, the refiner model is not unloaded - webui swaps back to the main model here; if a model override is present, it will be reloaded afterwards
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
for k, v in p.override_settings.items():
- setattr(opts, k, v)
+ opts.set(k, v, is_api=True, run_callbacks=False)
if k == 'sd_model_checkpoint':
sd_models.reload_model_weights()
@@ -732,7 +742,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
- if type(p.prompt) == list:
+ if isinstance(p.prompt, list):
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None
@@ -748,7 +758,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None:
p.tiling = opts.tiling
- if p.refiner_checkpoint not in (None, "", "None"):
+ if p.refiner_checkpoint not in (None, "", "None", "none"):
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
if p.refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
@@ -763,12 +773,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_prompts()
- if type(seed) == list:
+ if isinstance(seed, list):
p.all_seeds = seed
else:
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
- if type(subseed) == list:
+ if isinstance(subseed, list):
p.all_subseeds = subseed
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
@@ -1146,6 +1156,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc()
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
+ if shared.state.interrupted:
+ return samples
+
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
@@ -1259,12 +1272,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_negative_prompt == '':
self.hr_negative_prompt = self.negative_prompt
- if type(self.hr_prompt) == list:
+ if isinstance(self.hr_prompt, list):
self.all_hr_prompts = self.hr_prompt
else:
self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
- if type(self.hr_negative_prompt) == list:
+ if isinstance(self.hr_negative_prompt, list):
self.all_hr_negative_prompts = self.hr_negative_prompt
else:
self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
@@ -1382,7 +1395,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = self.image_mask
if image_mask is not None:
- image_mask = image_mask.convert('L')
+ # image_mask is passed in as RGBA by Gradio to support alpha masks,
+ # but we still want to support binary masks.
+ image_mask = create_binary_mask(image_mask)
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1501,7 +1516,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
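
An illustration (not from the commit; p stands for any StableDiffusionProcessing instance, with no styles applied) of the new broadcasting rule in setup_prompts:

```python
p.prompt = "a castle"                          # scalar prompt
p.negative_prompt = ["blurry", "low quality"]  # list of two
p.setup_prompts()
# p.all_prompts          == ["a castle", "a castle"]   (broadcast to match)
# p.all_negative_prompts == ["blurry", "low quality"]

p.prompt = ["a", "b", "c"]
p.negative_prompt = ["x", "y"]
# setup_prompts() now raises RuntimeError: 3 prompts vs 2 negative prompts
```
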
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 3c5b37d2..29ccb78f 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -5,7 +5,7 @@ from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
-class ScriptRefiner(scripts.Script):
+class ScriptRefiner(scripts.ScriptBuiltinUI):
section = "accordions"
create_group = False
@@ -42,7 +42,7 @@ class ScriptRefiner(scripts.Script):
# the actual implementation is in sd_samplers_common.py, apply_refiner
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
- p.refiner_checkpoint_info = None
+ p.refiner_checkpoint = None
p.refiner_switch_at = None
else:
p.refiner_checkpoint = refiner_checkpoint
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index 6ce3b2fc..6b6ff987 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -7,7 +7,7 @@ from modules.shared import cmd_opts
from modules.ui_components import ToolButton
-class ScriptSeed(scripts.ScriptBuiltin):
+class ScriptSeed(scripts.ScriptBuiltinUI):
section = "seed"
create_group = False
diff --git a/modules/progress.py b/modules/progress.py
index f405f07f..69921de7 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -48,6 +48,7 @@ def add_task_to_queue(id_job):
class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received preview image")
+ live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
class ProgressResponse(BaseModel):
@@ -71,7 +72,12 @@ def progressapi(req: ProgressRequest):
completed = req.id_task in finished_tasks
if not active:
- return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+ textinfo = "Waiting..."
+ if queued:
+ sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
+ queue_index = sorted_queued.index(req.id_task)
+ textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
+ return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
progress = 0
@@ -89,31 +95,30 @@ def progressapi(req: ProgressRequest):
predicted_duration = elapsed_since_start / progress if progress > 0 else None
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+ live_preview = None
id_live_preview = req.id_live_preview
- shared.state.set_current_image()
- if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
- image = shared.state.current_image
- if image is not None:
- buffered = io.BytesIO()
-
- if opts.live_previews_image_format == "png":
- # using optimize for large images takes an enormous amount of time
- if max(*image.size) <= 256:
- save_kwargs = {"optimize": True}
+
+ if opts.live_previews_enable and req.live_preview:
+ shared.state.set_current_image()
+ if shared.state.id_live_preview != req.id_live_preview:
+ image = shared.state.current_image
+ if image is not None:
+ buffered = io.BytesIO()
+
+ if opts.live_previews_image_format == "png":
+ # using optimize for large images takes an enormous amount of time
+ if max(*image.size) <= 256:
+ save_kwargs = {"optimize": True}
+ else:
+ save_kwargs = {"optimize": False, "compress_level": 1}
+
else:
- save_kwargs = {"optimize": False, "compress_level": 1}
-
- else:
- save_kwargs = {}
-
- image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
- base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
- live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
- id_live_preview = shared.state.id_live_preview
- else:
- live_preview = None
- else:
- live_preview = None
+ save_kwargs = {}
+
+ image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
+ base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+ live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
+ id_live_preview = shared.state.id_live_preview
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
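
A hedged client-side sketch (illustrative values; progressapi is mounted at /internal/progress in this module) showing the two behavior changes: live_preview=false now skips preview encoding entirely, and queued tasks report their position:

```python
import requests

resp = requests.post("http://127.0.0.1:7860/internal/progress", json={
    "id_task": "task(abc123)",  # hypothetical task id
    "id_live_preview": -1,
    "live_preview": False,      # new flag: don't encode/return the base64 preview
})
data = resp.json()
print(data["textinfo"])         # e.g. "In queue: 1/2" while the task waits
```
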
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index e8c41f38..334efeef 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -107,7 +107,7 @@ def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=N
yield args[(step - 1) % len(args)]
def start(self, args):
def flatten(x):
- if type(x) == str:
+ if isinstance(x, str):
yield x
else:
for gen in x:
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 0700b853..02841c30 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -55,6 +55,7 @@ class UpscalerRealESRGAN(Upscaler):
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
+ device=self.device,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
diff --git a/modules/rng.py b/modules/rng.py
index f927a318..9e8ba2ee 100644
--- a/modules/rng.py
+++ b/modules/rng.py
@@ -98,7 +98,7 @@ def slerp(val, low, high):
class ImageRNG:
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
- self.shape = shape
+ self.shape = tuple(map(int, shape))
self.seeds = seeds
self.subseeds = subseeds
self.subseed_strength = subseed_strength
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 77ee55ee..fab23551 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -28,6 +28,15 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+class ExtraNoiseParams:
+ def __init__(self, noise, x):
+ self.noise = noise
+ """Random noise generated by the seed"""
+
+ self.x = x
+ """Latent image representation of the image"""
+
+
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
@@ -100,6 +109,7 @@ callback_map = dict(
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
+ callbacks_extra_noise=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
@@ -189,6 +199,14 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
+def extra_noise_callback(params: ExtraNoiseParams):
+ for c in callback_map['callbacks_extra_noise']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'callbacks_extra_noise')
+
+
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
try:
@@ -367,6 +385,14 @@ def on_image_saved(callback):
add_callback(callback_map['callbacks_image_saved'], callback)
+def on_extra_noise(callback):
+ """register a function to be called before adding extra noise in img2img or hires fix;
+ The callback is called with one argument:
+ - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
+ """
+ add_callback(callback_map['callbacks_extra_noise'], callback)
+
+
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
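
A minimal extension-side sketch (the 0.5 factor is made up) of the new callback; whatever is assigned back to params.noise is what the sampler will add:

```python
from modules import script_callbacks

def dampen_extra_noise(params: script_callbacks.ExtraNoiseParams):
    # params.noise is the seed-determined noise, params.x the latent image
    params.noise = params.noise * 0.5

script_callbacks.on_extra_noise(dampen_extra_noise)
```
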
diff --git a/modules/scripts.py b/modules/scripts.py
index cbdac2b5..e8518ad0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -68,6 +68,9 @@ class Script:
on_after_component_elem_id = None
"""list of callbacks to be called after a component with an elem_id is created"""
+ setup_for_ui_only = False
+ """If true, the script setup will only be run in Gradio UI, not in API"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -258,7 +261,6 @@ class Script:
self.on_after_component_elem_id.append((elem_id, callback))
-
def describe(self):
"""unused"""
return ""
@@ -267,7 +269,7 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
- tabkind = 'img2img' if self.is_img2img else 'txt2txt'
+ tabkind = 'img2img' if self.is_img2img else 'txt2img'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
@@ -280,13 +282,14 @@ class Script:
pass
-class ScriptBuiltin(Script):
+class ScriptBuiltinUI(Script):
+ setup_for_ui_only = True
def elem_id(self, item_id):
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
- tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+ tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
return f'{tabname}{item_id}'
@@ -728,8 +731,11 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
- def setup_scrips(self, p):
+ def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts:
+ if not is_ui and script.setup_for_ui_only:
+ continue
+
try:
script_args = p.script_args[script.args_from:script.args_to]
script.setup(p, *script_args)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 695c5736..8863107a 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
```
"""
- def __init__(self, state_dict, device):
+ def __init__(self, state_dict, device, weight_dtype_conversion=None):
super().__init__()
self.state_dict = state_dict
self.device = device
+ self.weight_dtype_conversion = weight_dtype_conversion or {}
+ self.default_dtype = self.weight_dtype_conversion.get('')
+
+ def get_weight_dtype(self, key):
+ key_first_term, _ = key.split('.', 1)
+ return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,23 +173,60 @@ class LoadStateDictOnMeta(ReplaceHelper):
sd = self.state_dict
device = self.device
- def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
- params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+ def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
+ used_param_keys = []
- for name, param in params:
- if param.is_meta:
- self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+ for name, param in module._parameters.items():
+ if param is None:
+ continue
- original(self, state_dict, prefix, *args, **kwargs)
+ key = prefix + name
+ sd_param = sd.pop(key, None)
+ if sd_param is not None:
+ state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
+ used_param_keys.append(key)
- for name, _ in params:
+ if param.is_meta:
+ dtype = sd_param.dtype if sd_param is not None else param.dtype
+ module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+
+ for name in module._buffers:
key = prefix + name
- if key in sd:
- del sd[key]
+ sd_param = sd.pop(key, None)
+ if sd_param is not None:
+ state_dict[key] = sd_param
+ used_param_keys.append(key)
+
+ original(module, state_dict, prefix, *args, **kwargs)
+
+ for key in used_param_keys:
+ state_dict.pop(key, None)
+
+ def load_state_dict(original, module, state_dict, strict=True):
+ """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
+ because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
+ all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
+
+ In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
+
+ The dangerous thing about this is that if _load_from_state_dict is not called (if some exotic module overloads
+ the function and does not call the original), the state dict will simply fail to load because the weights
+ would be on the meta device.
+ """
+
+ if state_dict == sd:
+ state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
+
+ original(module, state_dict, strict=strict)
+
+ module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
+ module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+ layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
+ group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 46652fbd..592f0055 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -245,7 +245,21 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
- if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
+ embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
+ conditioner.embedders[i] = embedder.wrapped
+ if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
+ embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
+ conditioner.embedders[i] = embedder.wrapped
+
+ if hasattr(m, 'cond_stage_model'):
+ delattr(m, 'cond_stage_model')
+
+ elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6fbdcd6..547e93c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -343,7 +343,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
- if not shared.cmd_opts.no_half:
+ if shared.cmd_opts.no_half:
+ model.float()
+ devices.dtype_unet = torch.float32
+ timer.record("apply float()")
+ else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -359,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
+ devices.dtype_unet = torch.float16
timer.record("apply half()")
- devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -482,8 +486,12 @@ class SdModelData:
return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -510,7 +518,7 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
@@ -518,10 +526,17 @@ def send_model_to_cpu(m):
devices.torch_gc()
-def send_model_to_device(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+def model_target_device(m):
+ if lowvram.is_needed(m):
+ return devices.cpu
else:
+ return devices.device
+
+
+def send_model_to_device(m):
+ lowvram.apply(m)
+
+ if not m.lowvram:
m.to(shared.device)
@@ -579,7 +594,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
@@ -642,13 +665,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
+ model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -660,6 +684,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
@@ -716,7 +744,7 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
new file mode 100644
index 00000000..5ffd2f4f
--- /dev/null
+++ b/modules/sd_models_types.py
@@ -0,0 +1,31 @@
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ from modules.sd_models import CheckpointInfo
+
+
+class WebuiSdModel(LatentDiffusion):
+ """This class is not actually instantinated, but its fields are created and fieeld by webui"""
+
+ lowvram: bool
+ """True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
+
+ sd_model_hash: str
+ """short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
+
+ sd_model_checkpoint: str
+ """path to the file on disk that model weights were obtained from"""
+
+ sd_checkpoint_info: 'CheckpointInfo'
+ """structure with additional information about the file with model's weights"""
+
+ is_sdxl: bool
+ """True if the model's architecture is SDXL"""
+
+ is_sd2: bool
+ """True if the model's architecture is SD 2.x"""
+
+ is_sd1: bool
+ """True if the model's architecture is SD 1.x"""
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index bc9b97e4..b8101d38 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -165,7 +165,7 @@ class CFGDenoiser(torch.nn.Module):
else:
cond_in = catenate_conds([tensor, uncond])
- if shared.batch_cond_uncond:
+ if shared.opts.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
@@ -175,7 +175,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 07fc4434..60fa161c 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -35,22 +35,27 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD":
def samples_to_images_tensor(sample, approximation=None, model=None):
- '''latents -> images [-1, 1]'''
- if approximation is None:
+ """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
+
+ if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
+ from modules import lowvram
+ if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
+ approximation = 1
+
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
- x_sample = sample * 1.5
- x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+ x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else:
if model is None:
model = shared.sd_model
- x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+ with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+ x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
@@ -217,6 +222,7 @@ class Sampler:
self.eta_option_field = 'eta_ancestral'
self.eta_infotext_field = 'Eta'
+ self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key
@@ -273,7 +279,7 @@ class Sampler:
extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
+ if self.eta != self.eta_default:
p.extra_generation_params[self.eta_infotext_field] = self.eta
extra_params_kwargs['eta'] = self.eta
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 67853ff1..b9e0d577 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -3,6 +3,7 @@ import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
+from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
@@ -16,8 +17,8 @@ samplers_k_diffusion = [
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
@@ -34,7 +35,7 @@ samplers_k_diffusion = [
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
- ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
+ ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
]
@@ -145,6 +146,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
xi = x + noise * sigma_sched[0]
+ if opts.img2img_extra_noise > 0:
+ p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
+ extra_noise_params = ExtraNoiseParams(noise, x)
+ extra_noise_callback(extra_noise_params)
+ noise = extra_noise_params.noise
+ xi += noise * opts.img2img_extra_noise
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index c1f534ed..7a6cbd46 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -3,6 +3,7 @@ import inspect
import sys
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
+from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
@@ -76,6 +77,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_option_field = 'eta_ddim'
self.eta_infotext_field = 'Eta DDIM'
+ self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
@@ -103,6 +105,13 @@ class CompVisSampler(sd_samplers_common.Sampler):
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
+ if opts.img2img_extra_noise > 0:
+ p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
+ extra_noise_params = ExtraNoiseParams(noise, x)
+ extra_noise_callback(extra_noise_params)
+ noise = extra_noise_params.noise
+ xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 6d708ad2..5525cfbc 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -47,7 +47,7 @@ def apply_unet(option=None):
if current_unet_option is None:
current_unet = None
- if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ if not shared.sd_model.lowvram:
shared.sd_model.model.diffusion_model.to(devices.device)
return
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index fd9a1c2a..669097da 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -70,7 +70,6 @@ def get_filename(filepath):
def refresh_vae_list():
- global vae_dict
vae_dict.clear()
paths = [
@@ -104,7 +103,7 @@ def refresh_vae_list():
name = get_filename(filepath)
vae_dict[name] = filepath
- vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+ items = sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))
+ vae_dict.clear()
+ vae_dict.update(items)  # re-inserting after clear() actually applies the sort order; update() alone would not reorder existing keys
def find_vae_near_checkpoint(checkpoint_file):
@@ -160,7 +159,7 @@ def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
- if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
+ if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic):
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
return VaeResolution(resolved=False)
@@ -193,7 +192,7 @@ def load_vae_dict(filename, map_location):
def load_vae(model, vae_file=None, vae_source="from unknown source"):
- global vae_dict, loaded_vae_file
+ global vae_dict, base_vae, loaded_vae_file
# save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -231,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
restore_base_vae(model)
loaded_vae_file = vae_file
+ model.base_vae = base_vae
+ model.loaded_vae_file = loaded_vae_file
# don't call this from outside
@@ -262,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
if loaded_vae_file == vae_file:
return
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if sd_model.lowvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
@@ -274,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
print("VAE weights loaded.")
diff --git a/modules/shared.py b/modules/shared.py
index d9d01484..63661939 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -2,16 +2,15 @@ import sys
import gradio as gr
-from modules import shared_cmd_options, shared_gradio_themes, options, shared_items
+from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
-from ldm.models.diffusion.ddpm import LatentDiffusion
from modules import util
cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser
-batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
-parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
+batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
+parallel_processing_allowed = True
styles_filename = cmd_opts.styles_file
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
@@ -40,7 +39,7 @@ options_templates = None
opts = None
restricted_opts = None
-sd_model: LatentDiffusion = None
+sd_model: sd_models_types.WebuiSdModel = None
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py
index 485e89d5..822db0a9 100644
--- a/modules/shared_gradio_themes.py
+++ b/modules/shared_gradio_themes.py
@@ -36,7 +36,8 @@ gradio_hf_hub_themes = [
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
- "ysharma/steampunk"
+ "ysharma/steampunk",
+ "NoCrypt/miku"
]
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 69d9d70a..d1389838 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -111,6 +111,12 @@ options_templates.update(options_section(('system', "System"), {
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
+options_templates.update(options_section(('API', "API"), {
+ "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
+ "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
+ "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
+}))
+
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
@@ -166,7 +172,8 @@ For img2img, VAE is used to process user's input image before the sampling, and
options_templates.update(options_section(('img2img', "img2img"), {
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}, infotext='Noise multiplier'),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
+ "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
@@ -185,7 +192,8 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
- "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
+ "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
+ "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -232,6 +240,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
+ "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
@@ -279,13 +288,15 @@ options_templates.update(options_section(('ui', "Live previews"), {
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
+ "live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
+ "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
- "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
+ "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
diff --git a/modules/shared_state.py b/modules/shared_state.py
index 3dc9c788..d272ee5b 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -128,7 +128,7 @@ class State:
devices.torch_gc()
def set_current_image(self):
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
return
diff --git a/modules/ui.py b/modules/ui.py
index a6b1f964..2b6a13cb 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
-from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript
@@ -333,7 +333,7 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
extra_tabs.__enter__()
- with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False):
+ with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
scripts.scripts_txt2img.prepare_ui()
@@ -549,7 +549,7 @@ def create_ui():
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__()
- with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow(equal_height=False):
+ with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
copy_image_destinations = {}
@@ -575,7 +575,7 @@ def create_ui():
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
@@ -583,7 +583,7 @@ def create_ui():
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -598,7 +598,7 @@ def create_ui():
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 4c035f2a..eddc4bc8 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -132,7 +132,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
generation_info = None
with gr.Column():
diff --git a/modules/ui_components.py b/modules/ui_components.py
index d08b2b99..55979f62 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -20,6 +20,18 @@ class ToolButton(FormComponent, gr.Button):
return "button"
+class ResizeHandleRow(gr.Row):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.elem_classes.append("resize-handle-row")
+
+ def get_block_name(self):
+ return "row"
+
+
class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 15a8b0bf..e0138267 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -65,7 +65,7 @@ def save_config_state(name):
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
- json.dump(current_config_state, f)
+ json.dump(current_config_state, f, indent=4)
config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys())
@@ -200,119 +200,129 @@ def update_config_states_table(state_name):
created_date = time.asctime(time.gmtime(config_state["created_at"]))
filepath = config_state.get("filepath", "<unknown>")
- code = f"""<!-- {time.time()} -->"""
-
- webui_remote = config_state["webui"]["remote"] or ""
- webui_branch = config_state["webui"]["branch"]
- webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
- webui_commit_date = config_state["webui"]["commit_date"]
- if webui_commit_date:
- webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
- else:
- webui_commit_date = "<unknown>"
-
- remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
- commit_link = make_commit_link(webui_commit_hash, webui_remote)
- date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
-
- current_webui = config_states.get_webui_config()
-
- style_remote = ""
- style_branch = ""
- style_commit = ""
- if current_webui["remote"] != webui_remote:
- style_remote = STYLE_PRIMARY
- if current_webui["branch"] != webui_branch:
- style_branch = STYLE_PRIMARY
- if current_webui["commit_hash"] != webui_commit_hash:
- style_commit = STYLE_PRIMARY
-
- code += f"""<h2>Config Backup: {config_name}</h2>
- <div><b>Filepath:</b> {filepath}</div>
- <div><b>Created at:</b> {created_date}</div>"""
-
- code += f"""<h2>WebUI State</h2>
- <table id="config_state_webui">
- <thead>
- <tr>
- <th>URL</th>
- <th>Branch</th>
- <th>Commit</th>
- <th>Date</th>
- </tr>
- </thead>
- <tbody>
- <tr>
- <td><label{style_remote}>{remote}</label></td>
- <td><label{style_branch}>{webui_branch}</label></td>
- <td><label{style_commit}>{commit_link}</label></td>
- <td><label{style_commit}>{date_link}</label></td>
- </tr>
- </tbody>
- </table>
- """
-
- code += """<h2>Extension State</h2>
- <table id="config_state_extensions">
- <thead>
- <tr>
- <th>Extension</th>
- <th>URL</th>
- <th>Branch</th>
- <th>Commit</th>
- <th>Date</th>
- </tr>
- </thead>
- <tbody>
- """
-
- ext_map = {ext.name: ext for ext in extensions.extensions}
-
- for ext_name, ext_conf in config_state["extensions"].items():
- ext_remote = ext_conf["remote"] or ""
- ext_branch = ext_conf["branch"] or "<unknown>"
- ext_enabled = ext_conf["enabled"]
- ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
- ext_commit_date = ext_conf["commit_date"]
- if ext_commit_date:
- ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ try:
+ webui_remote = config_state["webui"]["remote"] or ""
+ webui_branch = config_state["webui"]["branch"]
+ webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
+ webui_commit_date = config_state["webui"]["commit_date"]
+ if webui_commit_date:
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
else:
- ext_commit_date = "<unknown>"
+ webui_commit_date = "<unknown>"
- remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
- commit_link = make_commit_link(ext_commit_hash, ext_remote)
- date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+ remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
+
+ current_webui = config_states.get_webui_config()
- style_enabled = ""
style_remote = ""
style_branch = ""
style_commit = ""
- if ext_name in ext_map:
- current_ext = ext_map[ext_name]
- current_ext.read_info_from_repo()
- if current_ext.enabled != ext_enabled:
- style_enabled = STYLE_PRIMARY
- if current_ext.remote != ext_remote:
- style_remote = STYLE_PRIMARY
- if current_ext.branch != ext_branch:
- style_branch = STYLE_PRIMARY
- if current_ext.commit_hash != ext_commit_hash:
- style_commit = STYLE_PRIMARY
-
- code += f"""
- <tr>
- <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
- <td><label{style_remote}>{remote}</label></td>
- <td><label{style_branch}>{ext_branch}</label></td>
- <td><label{style_commit}>{commit_link}</label></td>
- <td><label{style_commit}>{date_link}</label></td>
- </tr>
- """
-
- code += """
- </tbody>
- </table>
- """
+ if current_webui["remote"] != webui_remote:
+ style_remote = STYLE_PRIMARY
+ if current_webui["branch"] != webui_branch:
+ style_branch = STYLE_PRIMARY
+ if current_webui["commit_hash"] != webui_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code = f"""<!-- {time.time()} -->
+<h2>Config Backup: {config_name}</h2>
+<div><b>Filepath:</b> {filepath}</div>
+<div><b>Created at:</b> {created_date}</div>
+<h2>WebUI State</h2>
+<table id="config_state_webui">
+ <thead>
+ <tr>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>
+ <label{style_remote}>{remote}</label>
+ </td>
+ <td>
+ <label{style_branch}>{webui_branch}</label>
+ </td>
+ <td>
+ <label{style_commit}>{commit_link}</label>
+ </td>
+ <td>
+ <label{style_commit}>{date_link}</label>
+ </td>
+ </tr>
+ </tbody>
+</table>
+<h2>Extension State</h2>
+<table id="config_state_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+"""
+
+ ext_map = {ext.name: ext for ext in extensions.extensions}
+
+ for ext_name, ext_conf in config_state["extensions"].items():
+ ext_remote = ext_conf["remote"] or ""
+ ext_branch = ext_conf["branch"] or "<unknown>"
+ ext_enabled = ext_conf["enabled"]
+ ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
+ ext_commit_date = ext_conf["commit_date"]
+ if ext_commit_date:
+ ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ else:
+ ext_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
+ commit_link = make_commit_link(ext_commit_hash, ext_remote)
+ date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+
+ style_enabled = ""
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if ext_name in ext_map:
+ current_ext = ext_map[ext_name]
+ current_ext.read_info_from_repo()
+ if current_ext.enabled != ext_enabled:
+ style_enabled = STYLE_PRIMARY
+ if current_ext.remote != ext_remote:
+ style_remote = STYLE_PRIMARY
+ if current_ext.branch != ext_branch:
+ style_branch = STYLE_PRIMARY
+ if current_ext.commit_hash != ext_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f""" <tr>
+ <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{ext_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+"""
+
+ code += """ </tbody>
+</table>"""
+
+ except Exception as e:
+ print(f"[ERROR]: Config states {filepath}, {e}")
+ code = f"""<!-- {time.time()} -->
+<h2>Config Backup: {config_name}</h2>
+<div><b>Filepath:</b> {filepath}</div>
+<div><b>Created at:</b> {created_date}</div>
+<h2>This file is corrupted</h2>"""
return code
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index ebb5249f..ca6c2607 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -30,7 +30,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
- for index, name in enumerate(sd_models.checkpoints_list):
+ names = list(sd_models.checkpoints_list)
+ for index, name in enumerate(names):
yield self.create_item(name, index)
def allowed_directories_for_previews(self):
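Editor's note: materializing checkpoints_list into a list before enumerating guards against the checkpoint dict being refreshed while the generator is still running. The failure mode being avoided, in isolation:

    d = {"model-a": 1}
    it = iter(d)
    d["model-b"] = 2          # e.g. a refresh adds a checkpoint mid-iteration
    try:
        next(it)
    except RuntimeError as e:
        print(e)              # "dictionary changed size during iteration"

    names = list(d)           # snapshot first, as in the fix above
    d["model-c"] = 3
    assert names == ["model-a", "model-b"]   # the iteration set is frozen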
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 506017e5..85015db5 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -44,6 +44,8 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
+ else:
+ os.makedirs(dir, exist_ok=True)
use_metadata = False
metadata = PngImagePlugin.PngInfo()
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index da0e48aa..daaf761f 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -86,6 +86,15 @@ def confirm_checkpoints(p, xs):
raise RuntimeError(f"Unknown checkpoint: {x}")
+def confirm_checkpoints_or_none(p, xs):
+ for x in xs:
+ if x in (None, "", "None", "none"):
+ continue
+
+ if modules.sd_models.get_closet_checkpoint_match(x) is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+
+
def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x
@@ -191,6 +200,10 @@ def list_to_csv_string(data_list):
return o.getvalue().strip()
+def csv_string_to_list_strip(data_str):
+ return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str)))))
+
+
class AxisOption:
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
self.label = label
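Editor's note: the new helper centralizes the CSV-splitting idiom repeated below; note that, unlike the two call sites it replaces that filtered with `if x`, it keeps empty entries, so callers relying on empties being dropped must filter themselves. Expected behaviour, with illustrative values:

    import csv
    from io import StringIO
    from itertools import chain

    def csv_string_to_list_strip(data_str):
        return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str)))))

    assert csv_string_to_list_strip("1, 2 ,3") == ["1", "2", "3"]
    assert csv_string_to_list_strip('"juggernaut, v2",other') == ["juggernaut, v2", "other"]
    assert csv_string_to_list_strip("a,,b") == ["a", "", "b"]   # empties survive, unlike the old `if x` sites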
@@ -241,6 +254,8 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
+ AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
+ AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
@@ -250,6 +265,9 @@ axis_options = [
AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
+ AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
+ AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
+ AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
]
@@ -425,7 +443,6 @@ class Script(scripts.Script):
with gr.Column():
csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
-
with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
@@ -459,19 +476,19 @@ class Script(scripts.Script):
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
- current_values = axis_values
- current_dropdown_values = axis_values_dropdown
if has_choices:
choices = choices()
if csv_mode:
- current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
- current_values = list_to_csv_string(current_dropdown_values)
+ if axis_values_dropdown:
+ axis_values = list_to_csv_string(list(filter(lambda x: x in choices, axis_values_dropdown)))
+ axis_values_dropdown = []
else:
- current_dropdown_values = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(axis_values)))]
- current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
+ if axis_values:
+ axis_values_dropdown = list(filter(lambda x: x in choices, csv_string_to_list_strip(axis_values)))
+ axis_values = ""
- return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=current_values),
- gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=current_dropdown_values))
+ return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=axis_values),
+ gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=axis_values_dropdown))
x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
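Editor's note: the rewritten select_axis converts state in one direction per toggle — dropdown selections are serialized into the CSV textbox when csv_mode turns on, and the textbox is parsed (filtered against valid choices) back into the dropdown when it turns off. A round-trip sketch using the two helpers from this file; values are illustrative and this assumes the webui environment so the import resolves:

    from scripts.xyz_grid import csv_string_to_list_strip, list_to_csv_string

    choices = ["Euler a", "DDIM", "UniPC"]
    dropdown_values = ["Euler a", "UniPC"]

    as_csv = list_to_csv_string([x for x in dropdown_values if x in choices])
    assert as_csv == "Euler a,UniPC"          # csv_mode on: dropdown -> textbox

    back = [x for x in csv_string_to_list_strip(as_csv) if x in choices]
    assert back == dropdown_values            # csv_mode off: textbox -> dropdown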
@@ -488,7 +505,7 @@ class Script(scripts.Script):
def get_dropdown_update_from_params(axis, params):
val_key = f"{axis} Values"
vals = params.get(val_key, "")
- valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
+ valslist = csv_string_to_list_strip(vals)
return gr.update(value=valslist)
self.infotext_fields = (
@@ -519,7 +536,7 @@ class Script(scripts.Script):
if opt.choices is not None and not csv_mode:
valslist = vals_dropdown
else:
- valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
+ valslist = csv_string_to_list_strip(vals)
if opt.type == int:
valslist_ext = []
diff --git a/style.css b/style.css
index bdf0635a..d67b6336 100644
--- a/style.css
+++ b/style.css
@@ -137,11 +137,16 @@ a{
cursor: pointer;
}
-/* gradio 3.39 puts a lot of overflow: hidden all over the place for an unknown reqasaon. */
-.block.gradio-textbox, div.gradio-group, div.gradio-group div, div.gradio-dropdown{
+/* gradio 3.39 puts a lot of overflow: hidden all over the place for an unknown reason. */
+div.gradio-container, .block.gradio-textbox, div.gradio-group, div.gradio-dropdown{
overflow: visible !important;
}
+/* align-items isn't enough and elements may overflow in Safari. */
+.unequal-height {
+ align-content: flex-start;
+}
+
/* general styled components */
@@ -282,8 +287,8 @@ div.block.gradio-accordion {
}
}
-#txt2img_gallery img, #img2img_gallery img, #extras_gallery img{
- object-fit: scale-down;
+.gradio-gallery .thumbnails img {
+ object-fit: scale-down !important;
}
#txt2img_actions_column, #img2img_actions_column {
gap: 0.5em;
@@ -499,11 +504,15 @@ table.popup-table .link{
/* live preview */
.progressDiv{
- position: relative;
+ position: absolute;
height: 20px;
background: #b4c0cc;
border-radius: 3px !important;
- margin-bottom: -3px;
+ top: -20px;
+}
+
+[id$=_results].mobile{
+ margin-top: 28px;
}
.dark .progressDiv{
@@ -528,19 +537,16 @@ table.popup-table .link{
.livePreview{
position: absolute;
z-index: 300;
- background-color: white;
- margin: -4px;
-}
-
-.dark .livePreview{
- background-color: rgb(17 24 39 / var(--tw-bg-opacity));
+ background: var(--background-fill-primary);
+ width: 100%;
+ height: 100%;
}
.livePreview img{
position: absolute;
object-fit: contain;
width: 100%;
- height: 100%;
+ height: calc(100% - 60px); /* to match gradio's height */
}
/* fullscreen popup (ie in Lora's (i) button) */
@@ -609,13 +615,19 @@ table.popup-table .link{
display: flex;
gap: 1em;
padding: 1em;
- background-color: rgba(0,0,0,0.2);
+ background-color:rgba(0,0,0,0);
+ z-index: 1;
+ transition: 0.2s ease background-color;
+}
+.modalControls:hover {
+ background-color:rgba(0,0,0,0.9);
}
.modalClose {
margin-left: auto;
}
.modalControls span{
color: white;
+ text-shadow: 0px 0px 0.25em black;
font-size: 35px;
font-weight: bold;
cursor: pointer;
@@ -640,6 +652,13 @@ table.popup-table .link{
min-height: 0;
}
+#modalImage{
+ position: absolute;
+ top: 50%;
+ left: 50%;
+ transform: translateX(-50%) translateY(-50%);
+}
+
.modalPrev,
.modalNext {
cursor: pointer;
@@ -844,6 +863,7 @@ footer {
position: absolute;
color: white;
right: 0;
+ z-index: 1;
}
.extra-network-cards .card:hover .button-row{
display: flex;
@@ -1034,3 +1054,40 @@ div.accordions > div.input-accordion.input-accordion-open{
flex-flow: column;
}
+
+/* sticky right hand columns */
+
+#img2img_results, #txt2img_results, #extras_results {
+ position: sticky;
+ top: 0.5em;
+}
+
+body.resizing {
+ cursor: col-resize !important;
+}
+
+body.resizing * {
+ pointer-events: none !important;
+}
+
+body.resizing .resize-handle {
+ pointer-events: initial !important;
+}
+
+.resize-handle {
+ position: relative;
+ cursor: col-resize;
+ grid-column: 2 / 3;
+ min-width: 16px !important;
+ max-width: 16px !important;
+ height: 100%;
+}
+
+.resize-handle::after {
+ content: '';
+ position: absolute;
+ top: 0;
+ bottom: 0;
+ left: 7.5px;
+ border-left: 1px dashed var(--border-color-primary);
+}
diff --git a/webui.sh b/webui.sh
index cb8b9d14..3d0f87ee 100755
--- a/webui.sh
+++ b/webui.sh
@@ -141,8 +141,9 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
- export TORCH_COMMAND="pip install --pre torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 --index-url https://download.pytorch.org/whl/nightly/rocm5.5"
- # Navi 3 needs at least 5.5 which is only on the nightly chain
+ export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
+ # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5)
+ # so switch to nightly rocm5.6 without explicit versions this time
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
@@ -245,7 +246,7 @@ while [[ "$KEEP_GOING" -eq "1" ]]; do
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
- "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ "${python_cmd}" -u "${LAUNCH_SCRIPT}" "$@"
fi
if [[ ! -f tmp/restart ]]; then