aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.eslintrc.js1
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.yml67
-rw-r--r--README.md7
-rw-r--r--configs/alt-diffusion-m18-inference.yaml73
-rw-r--r--extensions-builtin/Lora/lora_logger.py33
-rw-r--r--extensions-builtin/Lora/lyco_helpers.py47
-rw-r--r--extensions-builtin/Lora/network.py1
-rw-r--r--extensions-builtin/Lora/network_glora.py33
-rw-r--r--extensions-builtin/Lora/network_oft.py97
-rw-r--r--extensions-builtin/Lora/networks.py57
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py7
-rw-r--r--extensions-builtin/mobile/javascript/mobile.js2
-rw-r--r--javascript/dragdrop.js2
-rw-r--r--javascript/edit-attention.js79
-rw-r--r--javascript/extraNetworks.js85
-rw-r--r--javascript/imageviewer.js7
-rw-r--r--javascript/inputAccordion.js81
-rw-r--r--javascript/notification.js6
-rw-r--r--javascript/settings.js46
-rw-r--r--javascript/token-counters.js26
-rw-r--r--javascript/ui.js32
-rw-r--r--modules/api/api.py66
-rw-r--r--modules/api/models.py36
-rw-r--r--modules/cmd_args.py9
-rw-r--r--modules/config_states.py3
-rw-r--r--modules/devices.py3
-rw-r--r--modules/generation_parameters_copypaste.py2
-rw-r--r--modules/gfpgan_model.py25
-rw-r--r--modules/gitpython_hack.py2
-rw-r--r--modules/hypernetworks/hypernetwork.py4
-rw-r--r--modules/images.py19
-rw-r--r--modules/img2img.py32
-rw-r--r--modules/initialize.py4
-rw-r--r--modules/initialize_util.py6
-rw-r--r--modules/launch_utils.py4
-rw-r--r--modules/localization.py21
-rw-r--r--modules/logging_config.py27
-rw-r--r--modules/options.py2
-rw-r--r--modules/paths.py2
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/postprocessing.py2
-rw-r--r--modules/processing.py13
-rw-r--r--modules/processing_scripts/seed.py4
-rw-r--r--modules/prompt_parser.py9
-rw-r--r--modules/restart.py4
-rw-r--r--modules/rng.py2
-rw-r--r--modules/script_callbacks.py6
-rw-r--r--modules/scripts.py6
-rw-r--r--modules/sd_hijack.py39
-rw-r--r--modules/sd_models.py53
-rw-r--r--modules/sd_models_config.py5
-rw-r--r--modules/sd_models_types.py5
-rw-r--r--modules/sd_unet.py4
-rw-r--r--modules/shared_cmd_options.py4
-rw-r--r--modules/shared_items.py6
-rw-r--r--modules/shared_options.py23
-rw-r--r--modules/shared_state.py2
-rw-r--r--modules/sub_quadratic_attention.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py78
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py299
-rw-r--r--modules/ui_common.py15
-rw-r--r--modules/ui_extensions.py2
-rw-r--r--modules/ui_extra_networks.py53
-rw-r--r--modules/ui_extra_networks_checkpoints.py10
-rw-r--r--modules/ui_extra_networks_hypernets.py13
-rw-r--r--modules/ui_extra_networks_textual_inversion.py10
-rw-r--r--modules/ui_gradio_extensions.py6
-rw-r--r--modules/ui_loadsave.py22
-rw-r--r--modules/ui_prompt_styles.py28
-rw-r--r--modules/ui_settings.py59
-rw-r--r--modules/ui_toprow.py141
-rw-r--r--modules/xlmr_m18.py164
-rw-r--r--requirements_versions.txt2
-rw-r--r--script.js33
-rw-r--r--scripts/postprocessing_upscale.py2
-rw-r--r--scripts/prompts_from_file.py32
-rw-r--r--scripts/xyz_grid.py7
-rw-r--r--style.css50
-rw-r--r--webui.bat5
-rw-r--r--webui.py2
-rwxr-xr-xwebui.sh19
82 files changed, 1724 insertions, 580 deletions
diff --git a/.eslintrc.js b/.eslintrc.js
index 4777c276..cf839769 100644
--- a/.eslintrc.js
+++ b/.eslintrc.js
@@ -74,6 +74,7 @@ module.exports = {
create_submit_args: "readonly",
restart_reload: "readonly",
updateInput: "readonly",
+ onEdit: "readonly",
//extraNetworks.js
requestGet: "readonly",
popup: "readonly",
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index cf6a2be8..5876e941 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -1,25 +1,45 @@
name: Bug Report
-description: You think somethings is broken in the UI
+description: You think something is broken in the UI
title: "[Bug]: "
labels: ["bug-report"]
body:
+ - type: markdown
+ attributes:
+ value: |
+ > The title of the bug report should be short and descriptive.
+ > Use relevant keywords for searchability.
+ > Do not leave it blank, but also do not put an entire error log in it.
- type: checkboxes
attributes:
- label: Is there an existing issue for this?
- description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+ label: Checklist
+ description: |
+ Please perform basic debugging to see if extensions or configuration is the cause of the issue.
+ Basic debug procedure
+  1. Disable all third-party extensions - check if extension is the cause
+  2. Update extensions and webui - sometimes things just need to be updated
+  3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
+  4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed
+  5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue
+ Before making a issue report please, check that the issue hasn't been reported recently.
options:
- - label: I have searched the existing issues and checked the recent builds/commits
- required: true
+ - label: The issue exists after disabling all extensions
+ - label: The issue exists on a clean installation of webui
+ - label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
+ - label: The issue exists in the current version of the webui
+ - label: The issue has not been reported before recently
+ - label: The issue has been reported before but has not been fixed yet
- type: markdown
attributes:
value: |
- *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
+ > Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible
- type: textarea
id: what-did
attributes:
label: What happened?
description: Tell us what happened in a very clear and simple way
+ placeholder: |
+ txt2img is not working as intended.
validations:
required: true
- type: textarea
@@ -27,9 +47,9 @@ body:
attributes:
label: Steps to reproduce the problem
description: Please provide us with precise step by step instructions on how to reproduce the bug
- value: |
- 1. Go to ....
- 2. Press ....
+ placeholder: |
+ 1. Go to ...
+ 2. Press ...
3. ...
validations:
required: true
@@ -38,13 +58,8 @@ body:
attributes:
label: What should have happened?
description: Tell us what you think the normal behavior should be
- validations:
- required: true
- - type: textarea
- id: sysinfo
- attributes:
- label: Sysinfo
- description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+ placeholder: |
+ WebUI should ...
validations:
required: true
- type: dropdown
@@ -58,12 +73,25 @@ body:
- Brave
- Apple Safari
- Microsoft Edge
+ - Android
+ - iOS
- Other
- type: textarea
+ id: sysinfo
+ attributes:
+ label: Sysinfo
+ description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+ placeholder: |
+ 1. Go to WebUI Settings -> Sysinfo -> Download system info.
+ If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file
+ 2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text.
+ validations:
+ required: true
+ - type: textarea
id: logs
attributes:
label: Console logs
- description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+ description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service.
render: Shell
validations:
required: true
@@ -71,4 +99,7 @@ body:
id: misc
attributes:
label: Additional information
- description: Please provide us with any relevant additional info or context.
+ description: |
+ Please provide us with any relevant additional info or context.
+ Examples:
+  I have updated my GPU driver recently.
diff --git a/README.md b/README.md
index 4e083440..6096c4a1 100644
--- a/README.md
+++ b/README.md
@@ -88,9 +88,10 @@ A browser interface based on Gradio library for Stable Diffusion.
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
-- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
+- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
+- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
@@ -103,7 +104,7 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Installation on Windows 10/11 with NVidia-GPUs using release package
-1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
+1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
2. Run `update.bat`.
3. Run `run.bat`.
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
@@ -146,7 +147,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
- CodeFormer - https://github.com/sczhou/CodeFormer
diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml
new file mode 100644
index 00000000..41a031d5
--- /dev/null
+++ b/configs/alt-diffusion-m18-inference.yaml
@@ -0,0 +1,73 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: modules.xlmr_m18.BertSeriesModelWithTransformation
+ params:
+ name: "XLMR-Large"
diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py
new file mode 100644
index 00000000..d51de297
--- /dev/null
+++ b/extensions-builtin/Lora/lora_logger.py
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+ COLORS = {
+ "DEBUG": "\033[0;36m", # CYAN
+ "INFO": "\033[0;32m", # GREEN
+ "WARNING": "\033[0;33m", # YELLOW
+ "ERROR": "\033[0;31m", # RED
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+ "RESET": "\033[0m", # RESET COLOR
+ }
+
+ def format(self, record):
+ colored_record = copy.copy(record)
+ levelname = colored_record.levelname
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+ return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(
+ ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+ )
+ logger.addHandler(handler)
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 279b34bc..1679a0ce 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+ '''
+ return a tuple of two value of input dimension decomposed by the number closest to factor
+ second value is higher or equal than first value.
+
+ In LoRA with Kroneckor Product, first value is a value for weight scale.
+ secon value is a value for weight.
+
+ Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
+
+ examples)
+ factor
+ -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
+ 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
+ 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
+ 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
+ 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
+ 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
+ '''
+
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+ while m<n:
+ new_m = m + 1
+ while dimension%new_m != 0:
+ new_m += 1
+ new_n = dimension // new_m
+ if new_m + new_n > length or new_m>factor:
+ break
+ else:
+ m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
+
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index d8e8dfb7..6021fd8d 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -93,6 +93,7 @@ class Network: # LoraModule
self.unet_multiplier = 1.0
self.dyn_dim = None
self.modules = {}
+ self.bundle_embeddings = {}
self.mtime = None
self.mentioned_name = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
new file mode 100644
index 00000000..492d4870
--- /dev/null
+++ b/extensions-builtin/Lora/network_glora.py
@@ -0,0 +1,33 @@
+
+import network
+
+class ModuleTypeGLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
+ return NetworkModuleGLora(net, weights)
+
+ return None
+
+# adapted from https://github.com/KohakuBlueleaf/LyCORIS
+class NetworkModuleGLora(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.w1a = weights.w["a1.weight"]
+ self.w1b = weights.w["b1.weight"]
+ self.w2a = weights.w["a2.weight"]
+ self.w2b = weights.w["b2.weight"]
+
+ def calc_updown(self, orig_weight):
+ w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [w1a.size(0), w1b.size(1)]
+ updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
new file mode 100644
index 00000000..05c37811
--- /dev/null
+++ b/extensions-builtin/Lora/network_oft.py
@@ -0,0 +1,97 @@
+import torch
+import network
+from lyco_helpers import factorization
+from einops import rearrange
+
+
+class ModuleTypeOFT(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
+ return NetworkModuleOFT(net, weights)
+
+ return None
+
+# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
+class NetworkModuleOFT(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+
+ super().__init__(net, weights)
+
+ self.lin_module = None
+ self.org_module: list[torch.Module] = [self.sd_module]
+
+ # kohya-ss
+ if "oft_blocks" in weights.w.keys():
+ self.is_kohya = True
+ self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
+ self.alpha = weights.w["alpha"] # alpha is constraint
+ self.dim = self.oft_blocks.shape[0] # lora dim
+ # LyCORIS
+ elif "oft_diag" in weights.w.keys():
+ self.is_kohya = False
+ self.oft_blocks = weights.w["oft_diag"]
+ # self.alpha is unused
+ self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
+
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+ is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
+
+ if is_linear:
+ self.out_dim = self.sd_module.out_features
+ elif is_conv:
+ self.out_dim = self.sd_module.out_channels
+ elif is_other_linear:
+ self.out_dim = self.sd_module.embed_dim
+
+ if self.is_kohya:
+ self.constraint = self.alpha * self.out_dim
+ self.num_blocks = self.dim
+ self.block_size = self.out_dim // self.dim
+ else:
+ self.constraint = None
+ self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+
+ def calc_updown_kb(self, orig_weight, multiplier):
+ oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+ oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
+
+ R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+ R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
+
+ # This errors out for MultiheadAttention, might need to be handled up-stream
+ merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+ merged_weight = torch.einsum(
+ 'k n m, k n ... -> k m ...',
+ R,
+ merged_weight
+ )
+ merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+
+ updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+ output_shape = orig_weight.shape
+ return self.finalize_updown(updown, orig_weight, output_shape)
+
+ def calc_updown(self, orig_weight):
+ # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
+ multiplier = self.multiplier()
+ return self.calc_updown_kb(orig_weight, multiplier)
+
+ # override to remove the multiplier/scale factor; it's already multiplied in get_weight
+ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
+ if self.bias is not None:
+ updown = updown.reshape(self.bias.shape)
+ updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = updown.reshape(output_shape)
+
+ if len(output_shape) == 4:
+ updown = updown.reshape(output_shape)
+
+ if orig_weight.size().numel() == updown.size().numel():
+ updown = updown.reshape(orig_weight.shape)
+
+ if ex_bias is not None:
+ ex_bias = ex_bias * self.multiplier()
+
+ return updown, ex_bias
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 96f935b2..7f814706 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -5,16 +5,21 @@ import re
import lora_patches
import network
import network_lora
+import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
+import network_oft
import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
module_types = [
network_lora.ModuleTypeLora(),
@@ -23,6 +28,8 @@ module_types = [
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
+ network_oft.ModuleTypeOFT(),
]
@@ -149,9 +156,19 @@ def load_network(name, network_on_disk):
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
matched_networks = {}
+ bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
+ if key_network_without_network_parts == "bundle_emb":
+ emb_name, vec_name = network_part.split(".", 1)
+ emb_dict = bundle_embeddings.get(emb_name, {})
+ if vec_name.split('.')[0] == 'string_to_param':
+ _, k2 = vec_name.split('.', 1)
+ emb_dict['string_to_param'] = {k2: weight}
+ else:
+ emb_dict[vec_name] = weight
+ bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -174,6 +191,17 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # kohya_ss OFT module
+ elif sd_module is None and "oft_unet" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ # KohakuBlueLeaf OFT module
+ if sd_module is None and "oft_diag" in key:
+ key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
@@ -195,6 +223,14 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module
+ embeddings = {}
+ for emb_name, data in bundle_embeddings.items():
+ embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+ embedding.loaded = None
+ embeddings[emb_name] = embedding
+
+ net.bundle_embeddings = embeddings
+
if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -210,11 +246,15 @@ def purge_networks_from_memory():
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ emb_db = sd_hijack.model_hijack.embedding_db
already_loaded = {}
for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
+ for emb_name, embedding in net.bundle_embeddings.items():
+ if embedding.loaded:
+ emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
loaded_networks.clear()
@@ -257,6 +297,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net)
+ for emb_name, embedding in net.bundle_embeddings.items():
+ if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+ logger.warning(
+ f'Skip bundle embedding: "{emb_name}"'
+ ' as it was already loaded from embeddings folder'
+ )
+ continue
+
+ embedding.loaded = False
+ if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+ embedding.loaded = True
+ emb_db.register_embedding(embedding, shared.sd_model)
+ else:
+ emb_db.skipped_embeddings[name] = embedding
+
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
@@ -418,6 +473,7 @@ def network_forward(module, input, original_forward):
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
self.network_current_names = ()
self.network_weights_backup = None
+ self.network_bias_backup = None
def network_Linear_forward(self, input):
@@ -564,6 +620,7 @@ extra_network_lora = None
available_networks = {}
available_network_aliases = {}
loaded_networks = []
+loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 55409a78..df02c663 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name)
+ if lora_on_disk is None:
+ return
path, ext = os.path.splitext(lora_on_disk.filename)
@@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
return item
def list_items(self):
- for index, name in enumerate(networks.available_networks):
+ # instantiate a list to protect against concurrent modification
+ names = list(networks.available_networks)
+ for index, name in enumerate(names):
item = self.create_item(name, index)
-
if item is not None:
yield item
diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js
index 652f07ac..bff1aced 100644
--- a/extensions-builtin/mobile/javascript/mobile.js
+++ b/extensions-builtin/mobile/javascript/mobile.js
@@ -12,6 +12,8 @@ function isMobile() {
}
function reportWindowSize() {
+ if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
+
var currentlyMobile = isMobile();
if (currentlyMobile == isSetupForMobile) return;
isSetupForMobile = currentlyMobile;
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index 5803daea..d680daf5 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -119,7 +119,7 @@ window.addEventListener('paste', e => {
}
const firstFreeImageField = visibleImageFields
- .filter(el => el.querySelector('input[type=file]'))?.[0];
+ .filter(el => !el.querySelector('img'))?.[0];
dropReplaceImage(
firstFreeImageField ?
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index 8906c892..688c2f11 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -18,37 +18,43 @@ function keyupEditAttention(event) {
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
- let beforeParenClose = before.lastIndexOf(CLOSE);
- while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
- beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
- beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
- }
+
+ let beforeClosingParen = before.lastIndexOf(CLOSE);
+ if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;
// Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
- let afterParenOpen = after.indexOf(OPEN);
- while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
- afterParen = after.indexOf(CLOSE, afterParen + 1);
- afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
- }
- if (beforeParen === -1 || afterParen === -1) return false;
+
+ let afterOpeningParen = after.indexOf(OPEN);
+ if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;
// Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
- const lastColon = parenContent.lastIndexOf(":");
- selectionStart = beforeParen + 1;
- selectionEnd = selectionStart + lastColon;
+ if (/.*:-?[\d.]+/s.test(parenContent)) {
+ const lastColon = parenContent.lastIndexOf(":");
+ selectionStart = beforeParen + 1;
+ selectionEnd = selectionStart + lastColon;
+ } else {
+ selectionStart = beforeParen + 1;
+ selectionEnd = selectionStart + parenContent.length;
+ }
+
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
function selectCurrentWord() {
if (selectionStart !== selectionEnd) return false;
- const delimiters = opts.keyedit_delimiters + " \r\n\t";
+ const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"};
+ let delimiters = opts.keyedit_delimiters;
+
+ for (let i of opts.keyedit_delimiters_whitespace) {
+ delimiters += whitespace_delimiters[i];
+ }
- // seek backward until to find beggining
+ // seek backward to find beginning
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
@@ -63,7 +69,7 @@ function keyupEditAttention(event) {
}
// If the user hasn't selected anything, let's select their current parenthesis block or word
- if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+ if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {
selectCurrentWord();
}
@@ -71,33 +77,54 @@ function keyupEditAttention(event) {
var closeCharacter = ')';
var delta = opts.keyedit_precision_attention;
+ var start = selectionStart > 0 ? text[selectionStart - 1] : "";
+ var end = text[selectionEnd];
- if (selectionStart > 0 && text[selectionStart - 1] == '<') {
+ if (start == '<') {
closeCharacter = '>';
delta = opts.keyedit_precision_extra;
- } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+ } else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))
+ let numParen = 0;
+
+ while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
+ numParen++;
+ }
+ if (start == "[") {
+ weight = (1 / 1.1) ** numParen;
+ } else {
+ weight = 1.1 ** numParen;
+ }
+
+ weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
+
+ text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
+ selectionStart -= numParen - 1;
+ selectionEnd -= numParen - 1;
+ } else if (start != '(') {
// do not include spaces at the end
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
- selectionEnd -= 1;
+ selectionEnd--;
}
+
if (selectionStart == selectionEnd) {
return;
}
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
- selectionStart += 1;
- selectionEnd += 1;
+ selectionStart++;
+ selectionEnd++;
}
- var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ if (text[selectionEnd] != ':') return;
+ var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12));
- if (String(weight).length == 1) weight += ".0";
+ if (Number.isInteger(weight)) weight += ".0";
if (closeCharacter == ')' && weight == 1) {
var endParenPos = text.substring(selectionEnd).indexOf(')');
@@ -105,7 +132,7 @@ function keyupEditAttention(event) {
selectionStart--;
selectionEnd--;
} else {
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);
}
target.focus();
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 493f31af..a1bf29a8 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
+ var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
+ var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
- sort.dataset.sortkey = 'sortDefault';
tabs.appendChild(searchDiv);
tabs.appendChild(sort);
tabs.appendChild(sortOrder);
@@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
elem.style.display = visible ? "" : "none";
});
+
+ applySort();
};
var applySort = function() {
+ var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
+
var reverse = sortOrder.classList.contains("sortReverse");
- var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
- sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
- var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
- if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
+ var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
+ sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
+ var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
+
+ if (sortKeyStore == sort.dataset.sortkey) {
return;
}
-
sort.dataset.sortkey = sortKeyStore;
- var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
cards.forEach(function(card) {
card.originalParentElement = card.parentElement;
});
@@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
};
search.addEventListener("input", applyFilter);
- applyFilter();
- ["change", "blur", "click"].forEach(function(evt) {
- sort.querySelector("input").addEventListener(evt, applySort);
- });
sortOrder.addEventListener("click", function() {
sortOrder.classList.toggle("sortReverse");
applySort();
});
+ applyFilter();
+ extraNetworksApplySort[tabname] = applySort;
extraNetworksApplyFilter[tabname] = applyFilter;
var showDirsUpdate = function() {
@@ -109,11 +111,47 @@ function setupExtraNetworksForTab(tabname) {
showDirsUpdate();
}
+function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
+ if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
+
+ var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
+ var prompt = gradioApp().getElementById(tabname + '_prompt_row');
+ var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
+ var elem = id ? gradioApp().getElementById(id) : null;
+
+ if (showNegativePrompt && elem) {
+ elem.insertBefore(negPrompt, elem.firstChild);
+ } else {
+ promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
+ }
+
+ if (showPrompt && elem) {
+ elem.insertBefore(prompt, elem.firstChild);
+ } else {
+ promptContainer.insertBefore(prompt, promptContainer.firstChild);
+ }
+}
+
+
+function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
+ extraNetworksMovePromptToTab(tabname, '', false, false);
+}
+
+function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
+ extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
+
+}
+
function applyExtraNetworkFilter(tabname) {
setTimeout(extraNetworksApplyFilter[tabname], 1);
}
+function applyExtraNetworkSort(tabname) {
+ setTimeout(extraNetworksApplySort[tabname], 1);
+}
+
var extraNetworksApplyFilter = {};
+var extraNetworksApplySort = {};
var activePromptTextarea = {};
function setupExtraNetworks() {
@@ -140,14 +178,15 @@ function setupExtraNetworks() {
onUiLoaded(setupExtraNetworks);
-var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
-var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
+var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
+var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var m = text.match(re_extranet);
var replaced = false;
var newTextareaText;
if (m) {
+ var extraTextBeforeNet = opts.extra_networks_add_text_separator;
var extraTextAfterNet = m[2];
var partToSearch = m[1];
var foundAtPosition = -1;
@@ -161,8 +200,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
return found;
});
- if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
- newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
+ if (foundAtPosition >= 0) {
+ if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+ newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
+ }
+ if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
+ newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition);
+ }
}
} else {
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
@@ -216,27 +260,24 @@ function extraNetworksSearchButton(tabs_id, event) {
var globalPopup = null;
var globalPopupInner = null;
+
function closePopup() {
if (!globalPopup) return;
-
globalPopup.style.display = "none";
}
+
function popup(contents) {
if (!globalPopup) {
globalPopup = document.createElement('div');
- globalPopup.onclick = closePopup;
globalPopup.classList.add('global-popup');
var close = document.createElement('div');
close.classList.add('global-popup-close');
- close.onclick = closePopup;
+ close.addEventListener("click", closePopup);
close.title = "Close";
globalPopup.appendChild(close);
globalPopupInner = document.createElement('div');
- globalPopupInner.onclick = function(event) {
- event.stopPropagation(); return false;
- };
globalPopupInner.classList.add('global-popup-inner');
globalPopup.appendChild(globalPopupInner);
@@ -335,7 +376,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
function extraNetworksRefreshSingleCard(page, tabname, name) {
requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
if (data && data.html) {
- var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function
+ var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`);
var newDiv = document.createElement('DIV');
newDiv.innerHTML = data.html;
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index c21d396e..e4dae91b 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage");
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
-
- if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
+ let preview = gradioApp().querySelectorAll('.livePreview > img');
+ if (preview.length > 0) {
+ // show preview image if available
+ modalImage.src = preview[preview.length - 1].src;
+ } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') {
const modal = gradioApp().getElementById("lightboxModal");
diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js
index f2839852..7570309a 100644
--- a/javascript/inputAccordion.js
+++ b/javascript/inputAccordion.js
@@ -1,37 +1,68 @@
-var observerAccordionOpen = new MutationObserver(function(mutations) {
- mutations.forEach(function(mutationRecord) {
- var elem = mutationRecord.target;
- var open = elem.classList.contains('open');
+function inputAccordionChecked(id, checked) {
+ var accordion = gradioApp().getElementById(id);
+ accordion.visibleCheckbox.checked = checked;
+ accordion.onVisibleCheckboxChange();
+}
- var accordion = elem.parentNode;
- accordion.classList.toggle('input-accordion-open', open);
+function setupAccordion(accordion) {
+ var labelWrap = accordion.querySelector('.label-wrap');
+ var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
+ var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
+ var span = labelWrap.querySelector('span');
+ var linked = true;
- var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
- checkbox.checked = open;
- updateInput(checkbox);
+ var isOpen = function() {
+ return labelWrap.classList.contains('open');
+ };
- var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
- if (extra) {
- extra.style.display = open ? "" : "none";
- }
+ var observerAccordionOpen = new MutationObserver(function(mutations) {
+ mutations.forEach(function(mutationRecord) {
+ accordion.classList.toggle('input-accordion-open', isOpen());
+
+ if (linked) {
+ accordion.visibleCheckbox.checked = isOpen();
+ accordion.onVisibleCheckboxChange();
+ }
+ });
});
-});
+ observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
-function inputAccordionChecked(id, checked) {
- var label = gradioApp().querySelector('#' + id + " .label-wrap");
- if (label.classList.contains('open') != checked) {
- label.click();
+ if (extra) {
+ labelWrap.insertBefore(extra, labelWrap.lastElementChild);
}
+
+ accordion.onChecked = function(checked) {
+ if (isOpen() != checked) {
+ labelWrap.click();
+ }
+ };
+
+ var visibleCheckbox = document.createElement('INPUT');
+ visibleCheckbox.type = 'checkbox';
+ visibleCheckbox.checked = isOpen();
+ visibleCheckbox.id = accordion.id + "-visible-checkbox";
+ visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
+ span.insertBefore(visibleCheckbox, span.firstChild);
+
+ accordion.visibleCheckbox = visibleCheckbox;
+ accordion.onVisibleCheckboxChange = function() {
+ if (linked && isOpen() != visibleCheckbox.checked) {
+ labelWrap.click();
+ }
+
+ gradioCheckbox.checked = visibleCheckbox.checked;
+ updateInput(gradioCheckbox);
+ };
+
+ visibleCheckbox.addEventListener('click', function(event) {
+ linked = false;
+ event.stopPropagation();
+ });
+ visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
}
onUiLoaded(function() {
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
- var labelWrap = accordion.querySelector('.label-wrap');
- observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
-
- var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
- if (extra) {
- labelWrap.insertBefore(extra, labelWrap.lastElementChild);
- }
+ setupAccordion(accordion);
}
});
diff --git a/javascript/notification.js b/javascript/notification.js
index 6d799561..3ee972ae 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
lastHeadImg = headImg;
// play notification sound if available
- gradioApp().querySelector('#audio_notification audio')?.play();
+ const notificationAudio = gradioApp().querySelector('#audio_notification audio');
+ if (notificationAudio) {
+ notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
+ notificationAudio.play();
+ }
if (document.hasFocus()) return;
diff --git a/javascript/settings.js b/javascript/settings.js
new file mode 100644
index 00000000..4e79ec00
--- /dev/null
+++ b/javascript/settings.js
@@ -0,0 +1,46 @@
+let settingsExcludeTabsFromShowAll = {
+ settings_tab_defaults: 1,
+ settings_tab_sysinfo: 1,
+ settings_tab_actions: 1,
+ settings_tab_licenses: 1,
+};
+
+function settingsShowAllTabs() {
+ gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
+ if (settingsExcludeTabsFromShowAll[elem.id]) return;
+
+ elem.style.display = "block";
+ });
+}
+
+function settingsShowOneTab() {
+ gradioApp().querySelector('#settings_show_one_page').click();
+}
+
+onUiLoaded(function() {
+ var edit = gradioApp().querySelector('#settings_search');
+ var editTextarea = gradioApp().querySelector('#settings_search > label > input');
+ var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages');
+ var settings_tabs = gradioApp().querySelector('#settings div');
+
+ onEdit('settingsSearch', editTextarea, 250, function() {
+ var searchText = (editTextarea.value || "").trim().toLowerCase();
+
+ gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) {
+ var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1;
+ elem.style.display = visible ? "" : "none";
+ });
+
+ if (searchText != "") {
+ settingsShowAllTabs();
+ } else {
+ settingsShowOneTab();
+ }
+ });
+
+ settings_tabs.insertBefore(edit, settings_tabs.firstChild);
+ settings_tabs.appendChild(buttonShowAllPages);
+
+
+ buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
+});
diff --git a/javascript/token-counters.js b/javascript/token-counters.js
index 9d81a723..2ecc7d91 100644
--- a/javascript/token-counters.js
+++ b/javascript/token-counters.js
@@ -1,10 +1,9 @@
-let promptTokenCountDebounceTime = 800;
-let promptTokenCountTimeouts = {};
-var promptTokenCountUpdateFunctions = {};
+let promptTokenCountUpdateFunctions = {};
function update_txt2img_tokens(...args) {
// Called from Gradio
update_token_counter("txt2img_token_button");
+ update_token_counter("txt2img_negative_token_button");
if (args.length == 2) {
return args[0];
}
@@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) {
function update_img2img_tokens(...args) {
// Called from Gradio
update_token_counter("img2img_token_button");
+ update_token_counter("img2img_negative_token_button");
if (args.length == 2) {
return args[0];
}
@@ -21,16 +21,7 @@ function update_img2img_tokens(...args) {
}
function update_token_counter(button_id) {
- if (opts.disable_token_counters) {
- return;
- }
- if (promptTokenCountTimeouts[button_id]) {
- clearTimeout(promptTokenCountTimeouts[button_id]);
- }
- promptTokenCountTimeouts[button_id] = setTimeout(
- () => gradioApp().getElementById(button_id)?.click(),
- promptTokenCountDebounceTime,
- );
+ promptTokenCountUpdateFunctions[button_id]?.();
}
@@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) {
prompt.parentElement.insertBefore(counter, prompt);
prompt.parentElement.style.position = "relative";
- promptTokenCountUpdateFunctions[id] = function() {
- update_token_counter(id_button);
- };
- textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
+ var func = onEdit(id, textarea, 800, function() {
+ gradioApp().getElementById(id_button)?.click();
+ });
+ promptTokenCountUpdateFunctions[id] = func;
+ promptTokenCountUpdateFunctions[id_button] = func;
}
function setupTokenCounters() {
diff --git a/javascript/ui.js b/javascript/ui.js
index bedcbf3e..2e262602 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -263,21 +263,6 @@ onAfterUiUpdate(function() {
json_elem.parentElement.style.display = "none";
setupTokenCounters();
-
- var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
- var settings_tabs = gradioApp().querySelector('#settings div');
- if (show_all_pages && settings_tabs) {
- settings_tabs.appendChild(show_all_pages);
- show_all_pages.onclick = function() {
- gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
- if (elem.id == "settings_tab_licenses") {
- return;
- }
-
- elem.style.display = "block";
- });
- };
- }
});
onOptionsChanged(function() {
@@ -366,3 +351,20 @@ function switchWidthHeight(tabname) {
updateInput(height);
return [];
}
+
+
+var onEditTimers = {};
+
+// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
+function onEdit(editId, elem, afterMs, func) {
+ var edited = function() {
+ var existingTimer = onEditTimers[editId];
+ if (existingTimer) clearTimeout(existingTimer);
+
+ onEditTimers[editId] = setTimeout(func, afterMs);
+ };
+
+ elem.addEventListener("input", edited);
+
+ return edited;
+}
diff --git a/modules/api/api.py b/modules/api/api.py
index e6edffe7..09083874 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,19 +17,18 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
+from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
-from typing import Dict, List, Any
+from typing import Any
import piexif
import piexif.helper
from contextlib import closing
@@ -103,7 +102,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
-
+ if isinstance(image, str):
+ return image
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -221,15 +221,15 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
- self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
- self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@@ -242,7 +242,8 @@ class Api:
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
- self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
+ self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -473,9 +474,6 @@ class Api:
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: models.PNGInfoRequest):
- if(not req.image.strip()):
- return models.PNGInfoResponse(info="")
-
image = decode_base64_to_image(req.image.strip())
if image is None:
return models.PNGInfoResponse(info="")
@@ -484,9 +482,10 @@ class Api:
if geninfo is None:
geninfo = ""
- items = {**{'parameters': geninfo}, **items}
+ params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
+ script_callbacks.infotext_pasted_callback(geninfo, params)
- return models.PNGInfoResponse(info=geninfo, items=items)
+ return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
@@ -541,12 +540,12 @@ class Api:
return {}
def unloadapi(self):
- unload_model_weights()
+ sd_models.unload_model_weights()
return {}
def reloadapi(self):
- reload_model_weights()
+ sd_models.send_model_to_device(shared.sd_model)
return {}
@@ -564,9 +563,9 @@ class Api:
return options
- def set_config(self, req: Dict[str, Any]):
+ def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
- if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+ if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
@@ -770,6 +769,25 @@ class Api:
cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda)
+ def get_extensions_list(self):
+ from modules import extensions
+ extensions.list_extensions()
+ ext_list = []
+ for ext in extensions.extensions:
+ ext: extensions.Extension
+ ext.read_info_from_repo()
+ if ext.remote is not None:
+ ext_list.append({
+ "name": ext.name,
+ "remote": ext.remote,
+ "branch": ext.branch,
+ "commit_hash":ext.commit_hash,
+ "commit_date":ext.commit_date,
+ "version":ext.version,
+ "enabled":ext.enabled
+ })
+ return ext_list
+
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
diff --git a/modules/api/models.py b/modules/api/models.py
index 6a574771..a0d80af8 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,12 +1,10 @@
import inspect
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
+from typing import Any, Optional, Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -168,17 +166,18 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
- items: dict = Field(title="Items", description="An object containing all the info the image had")
+ items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
+ parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@@ -232,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
- aliases: List[str] = Field(title="Aliases")
- options: Dict[str, str] = Field(title="Options")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
@@ -284,8 +283,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
- loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+ loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
@@ -303,11 +302,20 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
- choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+ choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
- args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+ args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+
+class ExtensionItem(BaseModel):
+ name: str = Field(title="Name", description="Extension name")
+ remote: str = Field(title="Remote", description="Extension Repository URL")
+ branch: str = Field(title="Branch", description="Extension Repository Branch")
+ commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
+ version: str = Field(title="Version", description="Extension Version")
+ commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
+ enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index aab62286..a9fb9bfa 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -90,7 +90,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
-parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
+parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
@@ -107,13 +107,14 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
-parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
+parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
-parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
+parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
diff --git a/modules/config_states.py b/modules/config_states.py
index b766aef1..651793c7 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -4,7 +4,6 @@ Supports saving and restoring webui and extensions from a known working set of c
import os
import json
-import time
import tqdm
from datetime import datetime
@@ -38,7 +37,7 @@ def list_config_states():
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states:
- timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
name = cs.get("name", "Config")
full_name = f"{name}: {timestamp}"
all_config_states[full_name] = cs
diff --git a/modules/devices.py b/modules/devices.py
index c01f0602..1d4eb563 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -60,7 +60,8 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
+ device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
+ if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d39f2eba..0a606515 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -9,7 +9,7 @@ from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks, processing
from PIL import Image
-re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
+re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 8e0f13bd..01d668ec 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(paths.models_path, model_dir)
+model_file_path = None
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None
@@ -17,6 +18,7 @@ loaded_gfpgan_model = None
def gfpgann():
global loaded_gfpgan_model
global model_path
+ global model_file_path
if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model
@@ -24,17 +26,24 @@ def gfpgann():
if gfpgan_constructor is None:
return None
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
+
if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
- latest_file = max(models, key=os.path.getctime)
+ gfp_models = []
+ for item in models:
+ if 'GFPGAN' in os.path.basename(item):
+ gfp_models.append(item)
+ latest_file = max(gfp_models, key=os.path.getctime)
model_file = latest_file
else:
print("Unable to load gfpgan model!")
return None
+
if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan
+ model_file_path = model_file
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model
@@ -77,19 +86,25 @@ def setup_model(dirname):
global user_path
global have_gfpgan
global gfpgan_constructor
+ global model_file_path
+
+ facexlib_path = model_path
+
+ if dirname is not None:
+ facexlib_path = dirname
load_file_from_url_orig = gfpgan.utils.load_file_from_url
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
gfpgan.utils.load_file_from_url = my_load_file_from_url
facexlib.detection.load_file_from_url = facex_load_file_from_url
diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py
index e537c1df..b55f0640 100644
--- a/modules/gitpython_hack.py
+++ b/modules/gitpython_hack.py
@@ -23,7 +23,7 @@ class Git(git.Git):
)
return self._parse_object_header(ret)
- def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
# Not really streaming, per se; this buffers the entire object in memory.
# Shouldn't be a problem for our use case, since we're only using this for
# object headers (commit objects).
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 70f1cbd2..be3e4648 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
@@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
diff --git a/modules/images.py b/modules/images.py
index eb644733..daf4eebe 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
})
piexif.insert(exif_bytes, filename)
+ elif extension.lower() == ".gif":
+ image.save(filename, format=image_format, comment=geninfo)
else:
image.save(filename, format=image_format, quality=opts.jpeg_quality)
@@ -661,7 +663,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
- os.replace(temp_file_path, filename_without_extension + extension)
+ filename = filename_without_extension + extension
+ if shared.opts.save_images_replace_action != "Replace":
+ n = 0
+ while os.path.exists(filename):
+ n += 1
+ filename = f"{filename_without_extension}-{n}{extension}"
+ os.replace(temp_file_path, filename)
fullfn_without_extension, extension = os.path.splitext(params.filename)
if hasattr(os, 'statvfs'):
@@ -718,7 +726,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
geninfo = items.pop('parameters', None)
if "exif" in items:
- exif = piexif.load(items["exif"])
+ exif_data = items["exif"]
+ try:
+ exif = piexif.load(exif_data)
+ except OSError:
+ # memory / exif was not valid so piexif tried to read from a file
+ exif = None
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
@@ -728,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
if exif_comment:
items['exif comment'] = exif_comment
geninfo = exif_comment
+ elif "comment" in items: # for gif
+ geninfo = items["comment"].decode('utf8', errors="ignore")
for field in IGNORED_INFO_KEYS:
items.pop(field, None)
diff --git a/modules/img2img.py b/modules/img2img.py
index 1519e132..52cb577a 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -10,6 +10,7 @@ from modules import images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
+from modules.sd_models import get_closet_checkpoint_match
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
@@ -41,7 +42,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
cfg_scale = p.cfg_scale
sampler_name = p.sampler_name
steps = p.steps
-
+ override_settings = p.override_settings
+ sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
@@ -104,15 +106,27 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
p.steps = int(parsed_parameters.get("Steps", steps))
+ model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
+ if model_info is not None:
+ p.override_settings['sd_model_checkpoint'] = model_info.name
+ elif sd_model_checkpoint_override:
+ p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
+ else:
+ p.override_settings.pop("sd_model_checkpoint", None)
+
+ if output_dir:
+ p.outpath_samples = output_dir
+ p.override_settings['save_to_dirs'] = False
+ p.override_settings['save_images_replace_action'] = "Add number suffix"
+ if p.n_iter > 1 or p.batch_size > 1:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+ else:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+
proc = modules.scripts.scripts_img2img.run(p, *args)
+
if proc is None:
- if output_dir:
- p.outpath_samples = output_dir
- p.override_settings['save_to_dirs'] = False
- if p.n_iter > 1 or p.batch_size > 1:
- p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
- else:
- p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+ p.override_settings.pop('save_images_replace_action', None)
process_images(p)
@@ -189,7 +203,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.user = request.username
- if shared.cmd_opts.enable_console_prompts:
+ if shared.opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask:
diff --git a/modules/initialize.py b/modules/initialize.py
index f24f7637..ac95fc6f 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -151,8 +151,8 @@ def initialize_rest(*, reload_script_modules=False):
from modules import devices
devices.first_time_calculation()
-
- Thread(target=load_model).start()
+ if not shared.cmd_opts.skip_load_model_at_start:
+ Thread(target=load_model).start()
from modules import shared_items
shared_items.reload_hypernetworks()
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2894eee4..2e9b6d89 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -150,10 +150,14 @@ def dumpstacks():
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
+
+ from modules import shared
+
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
- dumpstacks()
+ if shared.opts.dump_stacks_on_signal:
+ dumpstacks()
os._exit(0)
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 6e54d063..8cdbafa5 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -64,7 +64,7 @@ Use --skip-python-version-check to suppress this warning.
@lru_cache()
def commit_hash():
try:
- return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
+ return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception:
return "<none>"
@@ -72,7 +72,7 @@ def commit_hash():
@lru_cache()
def git_tag():
try:
- return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
+ return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
diff --git a/modules/localization.py b/modules/localization.py
index c1320288..108f792e 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -14,21 +14,24 @@ def list_localizations(dirname):
if ext.lower() != ".json":
continue
- localizations[fn] = os.path.join(dirname, file)
+ localizations[fn] = [os.path.join(dirname, file)]
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
- localizations[fn] = file.path
+ if fn not in localizations:
+ localizations[fn] = []
+ localizations[fn].append(file.path)
def localization_js(current_localization_name: str) -> str:
- fn = localizations.get(current_localization_name, None)
+ fns = localizations.get(current_localization_name, None)
data = {}
- if fn is not None:
- try:
- with open(fn, "r", encoding="utf8") as file:
- data = json.load(file)
- except Exception:
- errors.report(f"Error loading localization from {fn}", exc_info=True)
+ if fns is not None:
+ for fn in fns:
+ try:
+ with open(fn, "r", encoding="utf8") as file:
+ data.update(json.load(file))
+ except Exception:
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/logging_config.py b/modules/logging_config.py
index 7db23d4b..79269875 100644
--- a/modules/logging_config.py
+++ b/modules/logging_config.py
@@ -1,16 +1,41 @@
import os
import logging
+try:
+ from tqdm.auto import tqdm
+
+ class TqdmLoggingHandler(logging.Handler):
+ def __init__(self, level=logging.INFO):
+ super().__init__(level)
+
+ def emit(self, record):
+ try:
+ msg = self.format(record)
+ tqdm.write(msg)
+ self.flush()
+ except Exception:
+ self.handleError(record)
+
+ TQDM_IMPORTED = True
+except ImportError:
+ # tqdm does not exist before first launch
+ # I will import once the UI finishes seting up the enviroment and reloads.
+ TQDM_IMPORTED = False
def setup_logging(loglevel):
if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
+ loghandlers = []
+
+ if TQDM_IMPORTED:
+ loghandlers.append(TqdmLoggingHandler())
+
if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=loghandlers
)
-
diff --git a/modules/options.py b/modules/options.py
index 758b1ce5..ab40aff7 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -210,6 +210,8 @@ class Options:
def add_option(self, key, info):
self.data_labels[key] = info
+ if key not in self.data:
+ self.data[key] = info.default
def reorder(self):
"""reorder settings so that all items related to section always go together"""
diff --git a/modules/paths.py b/modules/paths.py
index 25052339..187b9496 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,6 +1,6 @@
import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
import modules.safe # noqa: F401
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 005a9b0a..89131a54 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -8,6 +8,7 @@ import shlex
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
+cwd = os.getcwd()
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index cf04d38b..fd0c0cc9 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -78,7 +78,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
image_data.close()
devices.torch_gc()
-
+ shared.state.end()
return outputs, ui_common.plaintext_to_html(infotext), ''
diff --git a/modules/processing.py b/modules/processing.py
index e124e7f0..b0e240a4 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -142,7 +142,7 @@ class StableDiffusionProcessing:
overlay_images: list = None
eta: float = None
do_not_reload_embeddings: bool = False
- denoising_strength: float = 0
+ denoising_strength: float = None
ddim_discretize: str = None
s_min_uncond: float = None
s_churn: float = None
@@ -296,7 +296,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
+ conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
return conditioning_image
@@ -533,6 +533,7 @@ class Processed:
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
+ self.version = program_version()
def js(self):
obj = {
@@ -567,6 +568,7 @@ class Processed:
"job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
+ "version": self.version,
}
return json.dumps(obj)
@@ -709,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.before_process(p)
- stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+ stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@@ -884,6 +886,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
+ state.nextjob()
+
if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@@ -956,7 +960,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- state.nextjob()
+ if not infotexts:
+ infotexts.append(Processed(p, []).infotext(p, 0))
p.color_corrections = None
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index 6b6ff987..dc9c2da5 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -29,8 +29,8 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
else:
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
- random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
- reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
+ random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
+ reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 334efeef..cba13455 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -2,10 +2,9 @@ from __future__ import annotations
import re
from collections import namedtuple
-from typing import List
import lark
-# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
+# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
@@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
- self.schedules: List[ScheduledPromptConditioning] = schedules
+ self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight
class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
- self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+ self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
@@ -278,7 +277,7 @@ class DictWithShape(dict):
return self["crossattn"].shape
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
is_dict = isinstance(param, dict)
diff --git a/modules/restart.py b/modules/restart.py
index 18eacaf3..2dd6493b 100644
--- a/modules/restart.py
+++ b/modules/restart.py
@@ -14,7 +14,9 @@ def is_restartable() -> bool:
def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
- (Path(script_path) / "tmp" / "restart").touch()
+ tmpdir = Path(script_path) / "tmp"
+ tmpdir.mkdir(parents=True, exist_ok=True)
+ (tmpdir / "restart").touch()
stop_program()
diff --git a/modules/rng.py b/modules/rng.py
index 9e8ba2ee..8934d39b 100644
--- a/modules/rng.py
+++ b/modules/rng.py
@@ -110,7 +110,7 @@ class ImageRNG:
self.is_first = True
def first(self):
- noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
+ noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
xs = []
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index c99695eb..9ed7ad21 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,7 +1,7 @@
import inspect
import os
from collections import namedtuple
-from typing import Optional, Dict, Any
+from typing import Optional, Any
from fastapi import FastAPI
from gradio import Blocks
@@ -258,7 +258,7 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']:
try:
c.callback(infotext, params)
@@ -449,7 +449,7 @@ def on_infotext_pasted(callback):
"""register a function to be called before applying an infotext.
The callback is called with two arguments:
- infotext: str - raw infotext.
- - result: Dict[str, any] - parsed infotext parameters.
+ - result: dict[str, any] - parsed infotext parameters.
"""
add_callback(callback_map['callbacks_infotext_pasted'], callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index e8518ad0..5c6e0226 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -491,11 +491,15 @@ class ScriptRunner:
arg_info = api_models.ScriptArg(label=control.label or "")
- for field in ("value", "minimum", "maximum", "step", "choices"):
+ for field in ("value", "minimum", "maximum", "step"):
v = getattr(control, field, None)
if v is not None:
setattr(arg_info, field, v)
+ choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
+ if choices is not None:
+ arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]
+
api_args.append(arg_info)
script.api_info = api_models.ScriptInfo(
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 592f0055..0157e19f 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,14 +2,15 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
+import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -37,6 +38,8 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
@@ -181,6 +184,20 @@ class StableDiffusionModelHijack:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
+ def convert_sdxl_to_ssd(self, m):
+ """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
+
+ delattr(m.model.diffusion_model.middle_block, '1')
+ delattr(m.model.diffusion_model.middle_block, '2')
+ for i in ['9', '8', '7', '6', '5', '4']:
+ delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
+ delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
+ devices.torch_gc()
+
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
@@ -208,7 +225,7 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
@@ -239,10 +256,17 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
- if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
- ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+ import modules.models.diffusion.ddpm_edit
+
+ if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
+ sd_unet.original_forward = sgm_original_forward
+ else:
+ sd_unet.original_forward = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
@@ -279,7 +303,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+ sd_unet.original_forward = None
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..841402e8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,22 +1,22 @@
import collections
import os.path
import sys
-import gc
import threading
import torch
import re
import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
+import numpy as np
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -49,11 +49,12 @@ class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
abspath = os.path.abspath(filename)
+ abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
+ name = abspath.replace(abs_ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
@@ -129,9 +130,12 @@ except Exception:
def setup_model():
+ """called once at startup to do various one-time tasks related to SD models"""
+
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
+ patch_given_betas()
def checkpoint_tiles(use_short=False):
@@ -309,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
+ # move to end as latest
+ checkpoints_loaded.move_to_end(checkpoint_info)
return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -346,16 +352,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
-
+ model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
+ if model.is_ssd:
+ sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
@@ -453,6 +462,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+def patch_given_betas():
+ import ldm.models.diffusion.ddpm
+
+ def patched_register_schedule(*args, **kwargs):
+ """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+ if isinstance(args[1], ListConfig):
+ args = (args[0], np.array(args[1]), *args[2:])
+
+ original_register_schedule(*args, **kwargs)
+
+ original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
+
+
def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
@@ -777,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- timer = Timer()
-
- if model_data.sd_model:
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- model_data.sd_model = None
- sd_model = None
- gc.collect()
- devices.torch_gc()
-
- print(f"Unloaded weights {timer.summary()}.")
+ send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 08dd03f1..deab2f6e 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
+config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
@@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
+
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+ if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
+ return config_alt_diffusion_m18
return config_alt_diffusion
return config_default
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index 5ffd2f4f..f911fbb6 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
"""structure with additional information about the file with model's weights"""
is_sdxl: bool
- """True if the model's architecture is SDXL"""
+ """True if the model's architecture is SDXL or SSD"""
+
+ is_ssd: bool
+ """True if the model is SSD"""
is_sd2: bool
"""True if the model's architecture is SD 2.x"""
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 5525cfbc..6a7bc9e2 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -1,11 +1,11 @@
import torch.nn
-import ldm.modules.diffusionmodules.openaimodel
from modules import script_callbacks, shared, devices
unet_options = []
current_unet_option = None
current_unet = None
+original_forward = None
def list_unets():
@@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
- return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+ return original_forward(self, x, timesteps, context, *args, **kwargs)
diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py
index dd93f520..c9626667 100644
--- a/modules/shared_cmd_options.py
+++ b/modules/shared_cmd_options.py
@@ -14,5 +14,5 @@ if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
else:
cmd_opts, _ = parser.parse_known_args()
-
-cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
+cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])
+cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 84d69c8d..5024b426 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -44,9 +44,9 @@ def refresh_unet_list():
modules.sd_unet.list_unets()
-def list_checkpoint_tiles():
+def list_checkpoint_tiles(use_short=False):
import modules.sd_models
- return modules.sd_models.checkpoint_tiles()
+ return modules.sd_models.checkpoint_tiles(use_short)
def refresh_checkpoints():
@@ -67,6 +67,8 @@ def reload_hypernetworks():
ui_reorder_categories_builtin_items = [
+ "prompt",
+ "image",
"inpaint",
"sampler",
"accordions",
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 00b273fa..f1003f21 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -26,7 +26,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
-
+ "save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
@@ -62,6 +62,9 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
+
+ "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
+ "notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -100,6 +103,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), {
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
+ "enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
@@ -109,6 +113,7 @@ options_templates.update(options_section(('system', "System"), {
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
+ "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
}))
options_templates.update(options_section(('API', "API"), {
@@ -133,7 +138,7 @@ options_templates.update(options_section(('training', "Training"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
@@ -230,6 +235,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
+ "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
+ "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
@@ -255,15 +262,20 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"),
+ "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
+ "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
+ "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
+ "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
+ "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
}))
@@ -305,8 +317,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
- 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
- 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
@@ -329,4 +341,3 @@ options_templates.update(options_section((None, "Hidden options"), {
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
-
diff --git a/modules/shared_state.py b/modules/shared_state.py
index d272ee5b..a68789cc 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -103,6 +103,7 @@ class State:
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
+ self.time_start = time.time()
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
@@ -114,7 +115,6 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
- self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index ae4ee4bb..4cb561ef 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,7 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-from typing import Optional, NamedTuple, List
+from typing import Optional, NamedTuple
def narrow_trunc(
@@ -97,7 +97,7 @@ def _query_chunk_attention(
)
return summarize_chunk(query, key_chunk, value_chunk)
- chunks: List[AttnChunk] = [
+ chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index aa79dc09..04dda585 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -181,40 +181,7 @@ class EmbeddingDatabase:
else:
return
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- shape = vec.shape[-1]
- vectors = vec.shape[0]
- elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
- vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
- shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
- vectors = data['clip_g'].shape[0]
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- shape = vec.shape[-1]
- vectors = vec.shape[0]
- else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vectors
- embedding.shape = shape
- embedding.filename = path
- embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+ embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
return fn
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+ if 'string_to_param' in data: # textual inversion embeddings
+ param_dict = data['string_to_param']
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+ vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+ shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+ vectors = data['clip_g'].shape[0]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vectors
+ embedding.shape = shape
+
+ if filepath:
+ embedding.filename = filepath
+ embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+ return embedding
+
+
def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
@@ -386,7 +392,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
from modules import processing
save_embedding_every = save_embedding_every or 0
@@ -590,7 +596,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 1ee592ad..e4e18ceb 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -3,7 +3,7 @@ from contextlib import closing
import modules.scripts
from modules import processing
from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.shared import opts, cmd_opts
+from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
import gradio as gr
@@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.user = request.username
- if cmd_opts.enable_console_prompts:
+ if shared.opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
with closing(p):
diff --git a/modules/ui.py b/modules/ui.py
index 579bab98..ba0d8542 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -4,6 +4,7 @@ import os
import sys
from functools import reduce
import warnings
+from contextlib import ExitStack
import gradio as gr
import gradio.utils
@@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -25,7 +26,6 @@ import modules.hypernetworks.ui as hypernetworks_ui
import modules.textual_inversion.ui as textual_inversion_ui
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
-import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.generation_parameters_copypaste import image_from_url_text
@@ -151,11 +151,15 @@ def connect_clear_prompt(button):
)
-def update_token_counter(text, steps):
+def update_token_counter(text, steps, *, is_positive=True):
try:
text, _ = extra_networks.parse_prompt(text)
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ if is_positive:
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ else:
+ prompt_flat_list = [text]
+
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
except Exception:
@@ -169,76 +173,9 @@ def update_token_counter(text, steps):
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
-class Toprow:
- """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
+def update_negative_prompt_token_counter(text, steps):
+ return update_token_counter(text, steps, is_positive=False)
- def __init__(self, is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
- self.id_part = id_part
-
- with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
- with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
- self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
- self.button_interrogate = None
- self.button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_classes="interrogate-col"):
- self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
- self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
- self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
- self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- self.skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- self.interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row(elem_id=f"{id_part}_tools"):
- self.paste = ToolButton(value=paste_symbol, elem_id="paste")
- self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
- self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
- self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
- self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
- self.clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[self.prompt, self.negative_prompt],
- outputs=[self.prompt, self.negative_prompt],
- )
-
- self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
-
- self.prompt_img.change(
- fn=modules.images.image_data,
- inputs=[self.prompt_img],
- outputs=[self.prompt, self.prompt_img],
- show_progress=False,
- )
def setup_progressbar(*args, **kwargs):
@@ -278,8 +215,8 @@ def apply_setting(key, value):
return getattr(opts, key)
-def create_output_panel(tabname, outdir):
- return ui_common.create_output_panel(tabname, outdir)
+def create_output_panel(tabname, outdir, toprow=None):
+ return ui_common.create_output_panel(tabname, outdir, toprow)
def create_sampler_and_steps_selection(choices, tabname):
@@ -326,7 +263,7 @@ def create_ui():
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- toprow = Toprow(is_img2img=False)
+ toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
dummy_component = gr.Label(visible=False)
@@ -334,10 +271,17 @@ def create_ui():
extra_tabs.__enter__()
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
- with gr.Column(variant='compact', elem_id="txt2img_settings"):
+ with ExitStack() as stack:
+ if shared.opts.txt2img_settings_accordion:
+ stack.enter_context(gr.Accordion("Open for Settings", open=False))
+ stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
+
scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
+ if category == "prompt":
+ toprow.create_inline_toprow_prompts()
+
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
@@ -348,7 +292,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
@@ -432,7 +376,7 @@ def create_ui():
show_progress=False,
)
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
@@ -533,7 +477,7 @@ def create_ui():
]
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
- toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@@ -544,13 +488,17 @@ def create_ui():
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- toprow = Toprow(is_img2img=True)
+ toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__()
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
- with gr.Column(variant='compact', elem_id="img2img_settings"):
+ with ExitStack() as stack:
+ if shared.opts.img2img_settings_accordion:
+ stack.enter_context(gr.Accordion("Open for Settings", open=False))
+ stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
+
copy_image_buttons = []
copy_image_destinations = {}
@@ -567,85 +515,89 @@ def create_ui():
button = gr.Button(title)
copy_image_buttons.append((button, name, elem))
- with gr.Tabs(elem_id="mode_img2img"):
- img2img_selected_tab = gr.State(0)
-
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
- add_copy_image_controls('img2img', init_img)
-
- with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
- add_copy_image_controls('sketch', sketch)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
- add_copy_image_controls('inpaint', init_img_with_mask)
-
- with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
- inpaint_color_sketch_orig = gr.State(None)
- add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
-
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
-
- with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
-
- with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
- hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(
- "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
- "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
- f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
- f"{hidden}</p>"
- )
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
- with gr.Accordion("PNG info", open=False):
- img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
- img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
- img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
-
- img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
-
- for i, tab in enumerate(img2img_tabs):
- tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
-
- def copy_image(img):
- if isinstance(img, dict) and 'image' in img:
- return img['image']
-
- return img
-
- for button, name, elem in copy_image_buttons:
- button.click(
- fn=copy_image,
- inputs=[elem],
- outputs=[copy_image_destinations[name]],
- )
- button.click(
- fn=lambda: None,
- _js=f"switch_to_{name.replace(' ', '_')}",
- inputs=[],
- outputs=[],
- )
-
- with FormRow():
- resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
-
scripts.scripts_img2img.prepare_ui()
for category in ordered_ui_categories():
+ if category == "prompt":
+ toprow.create_inline_toprow_prompts()
+
+ if category == "image":
+ with gr.Tabs(elem_id="mode_img2img"):
+ img2img_selected_tab = gr.State(0)
+
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
+ add_copy_image_controls('img2img', init_img)
+
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
+ add_copy_image_controls('sketch', sketch)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
+ add_copy_image_controls('inpaint', init_img_with_mask)
+
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
+ inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
+
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
+
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
+
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
+ hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(
+ "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+ "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
+ f"{hidden}</p>"
+ )
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Accordion("PNG info", open=False):
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
+
+ img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
+
+ for i, tab in enumerate(img2img_tabs):
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
+
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
+
+ return img
+
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js=f"switch_to_{name.replace(' ', '_')}",
+ inputs=[],
+ outputs=[],
+ )
+
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
@@ -661,8 +613,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
- detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
+ detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
@@ -746,20 +698,20 @@ def create_ui():
with gr.Column(scale=4):
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
- def select_img2img_tab(tab):
- return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
-
- for i, elem in enumerate(img2img_tabs):
- elem.select(
- fn=lambda tab=i: select_img2img_tab(tab),
- inputs=[],
- outputs=[inpaint_controls, mask_alpha],
- )
-
if category not in {"accordions"}:
scripts.scripts_img2img.setup_ui_for_section(category)
- img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate(img2img_tabs):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
@@ -1286,7 +1238,7 @@ def create_ui():
loadsave.setup_ui()
- if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
@@ -1338,7 +1290,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
def setup_ui_api(app):
from pydantic import BaseModel, Field
- from typing import List
class QuicksettingsHint(BaseModel):
name: str = Field(title="Name of the quicksettings field")
@@ -1347,7 +1298,7 @@ def setup_ui_api(app):
def quicksettings_hint():
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
- app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
+ app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 84a7d7f2..032ec4af 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-def create_output_panel(tabname, outdir):
+def create_output_panel(tabname, outdir, toprow=None):
def open_folder(f):
if not os.path.exists(f):
@@ -130,12 +130,15 @@ Requested path was: {f}
else:
sp.Popen(["xdg-open", path])
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+ with gr.Column(elem_id=f"{tabname}_results"):
+ if toprow:
+ toprow.create_inline_toprow_image()
- generation_info = None
- with gr.Column():
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+
+ generation_info = None
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 2e8c1d6d..c0a73b57 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -197,7 +197,7 @@ def update_config_states_table(state_name):
config_state = config_states.all_config_states[state_name]
config_name = config_state.get("name", "Config")
- created_date = time.asctime(time.gmtime(config_state["created_at"]))
+ created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
filepath = config_state.get("filepath", "<unknown>")
try:
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 063bd7b8..f03e2033 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,3 +1,4 @@
+import functools
import os.path
import urllib.parse
from pathlib import Path
@@ -15,6 +16,17 @@ from modules.ui_components import ToolButton
extra_pages = []
allowed_dirs = set()
+default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
+
+
+@functools.cache
+def allowed_preview_extensions_with_extra(extra_extensions=None):
+ return set(default_allowed_preview_extensions) | set(extra_extensions or [])
+
+
+def allowed_preview_extensions():
+ return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
+
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
@@ -33,9 +45,9 @@ def fetch_file(filename: str = ""):
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
- ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
+ ext = os.path.splitext(filename)[1].lower()[1:]
+ if ext not in allowed_preview_extensions():
+ raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -91,6 +103,7 @@ class ExtraNetworksPage:
self.name = title.lower()
self.id_page = self.name.replace(" ", "_")
self.card_page = shared.html("extra-networks-card.html")
+ self.allow_prompt = True
self.allow_negative_prompt = False
self.metadata = {}
self.items = {}
@@ -213,9 +226,9 @@ class ExtraNetworksPage:
metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
- edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
local_path = ""
filename = item.get("filename", "")
@@ -235,7 +248,7 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never":
return ""
- sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
+ sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
args = {
"background_image": background_image,
@@ -266,6 +279,7 @@ class ExtraNetworksPage:
"date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0),
"name": pth.name.lower(),
+ "path": str(pth.parent).lower(),
}
def find_preview(self, path):
@@ -273,11 +287,7 @@ class ExtraNetworksPage:
Find a preview PNG for a given path (without extension) and call link_preview on it.
"""
- preview_extensions = ["png", "jpg", "jpeg", "webp"]
- if shared.opts.samples_format not in preview_extensions:
- preview_extensions.append(shared.opts.samples_format)
-
- potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
+ potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
for file in potential_files:
if os.path.isfile(file):
@@ -359,7 +369,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs = []
for page in ui.stored_extra_pages:
- with gr.Tab(page.title, id=page.id_page) as tab:
+ with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
@@ -373,19 +383,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs.append(tab)
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
- dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
- button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
+ dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
+ button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+ tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
+
for tab in unrelated_tabs:
- tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
+ tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
+
+ for page, tab in zip(ui.stored_extra_pages, related_tabs):
+ allow_prompt = "true" if page.allow_prompt else "false"
+ allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
+
+ jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
+
+ tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
- for tab in related_tabs:
- tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
+ dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
def pages_html():
if not ui.pages_contents:
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index ca6c2607..1693e71f 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Checkpoints')
+ self.allow_prompt = False
+
def refresh(self):
shared.refresh_checkpoints()
def create_item(self, name, index=None, enable_filter=True):
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
+ if checkpoint is None:
+ return
+
path, ext = os.path.splitext(checkpoint.filename)
return {
"name": checkpoint.name_for_extra,
@@ -30,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
+ # instantiate a list to protect against concurrent modification
names = list(sd_models.checkpoints_list)
for index, name in enumerate(names):
- yield self.create_item(name, index)
+ item = self.create_item(name, index)
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 4cedf085..c96c4fa3 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
shared.reload_hypernetworks()
def create_item(self, name, index=None, enable_filter=True):
- full_path = shared.hypernetworks[name]
+ full_path = shared.hypernetworks.get(name)
+ if full_path is None:
+ return
+
path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None
@@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
- for index, name in enumerate(shared.hypernetworks):
- yield self.create_item(name, index)
+ # instantiate a list to protect against concurrent modification
+ names = list(shared.hypernetworks)
+ for index, name in enumerate(names):
+ item = self.create_item(name, index)
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir]
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 55ef0ea7..1b334fda 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+ if embedding is None:
+ return
path, ext = os.path.splitext(embedding.filename)
return {
@@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
- for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
- yield self.create_item(name, index)
+ # instantiate a list to protect against concurrent modification
+ names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
+ for index, name in enumerate(names):
+ item = self.create_item(name, index)
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index b824b113..0d368f8b 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -2,12 +2,12 @@ import os
import gradio as gr
from modules import localization, shared, scripts
-from modules.paths import script_path, data_path
+from modules.paths import script_path, data_path, cwd
def webpath(fn):
- if fn.startswith(script_path):
- web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+ if fn.startswith(cwd):
+ web_path = os.path.relpath(fn, cwd)
else:
web_path = os.path.abspath(fn)
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
index ec8fa8e8..eb20ff25 100644
--- a/modules/ui_loadsave.py
+++ b/modules/ui_loadsave.py
@@ -4,7 +4,7 @@ import os
import gradio as gr
from modules import errors
-from modules.ui_components import ToolButton
+from modules.ui_components import ToolButton, InputAccordion
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
@@ -32,8 +32,6 @@ class UiLoadsave:
self.error_loading = True
errors.display(e, "loading settings")
-
-
def add_component(self, path, x):
"""adds component to the registry of tracked components"""
@@ -43,20 +41,24 @@ class UiLoadsave:
key = f"{path}/{field}"
if getattr(obj, 'custom_script_source', None) is not None:
- key = f"customscript/{obj.custom_script_source}/{key}"
+ key = f"customscript/{obj.custom_script_source}/{key}"
if getattr(obj, 'do_not_save_to_config', False):
return
saved_value = self.ui_settings.get(key, None)
+
+ if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value':
+ field = 'open'
+
if saved_value is None:
self.ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
pass
else:
- if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
+ if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
saved_value = str(saved_value)
- elif isinstance(x, gr.Number) and field == 'value':
+ elif isinstance(obj, gr.Number) and field == 'value':
try:
saved_value = float(saved_value)
except ValueError:
@@ -67,7 +69,7 @@ class UiLoadsave:
init_field(saved_value)
if field == 'value' and key not in self.component_mapping:
- self.component_mapping[key] = x
+ self.component_mapping[key] = obj
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
apply_field(x, 'visible')
@@ -100,6 +102,12 @@ class UiLoadsave:
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+ if type(x) == InputAccordion:
+ if x.accordion.visible:
+ apply_field(x.accordion, 'visible')
+ apply_field(x, 'value')
+ apply_field(x.accordion, 'value')
+
def check_tab_id(tab_id):
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
if type(tab_id) == str:
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index d6f8d4c7..0d74c23f 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -4,6 +4,7 @@ from modules import shared, ui_common, ui_components, styles
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
styles_materialize_symbol = '\U0001f4cb' # 📋
+styles_copy_symbol = '\U0001f4dd' # 📝
def select_style(name):
@@ -52,6 +53,8 @@ def refresh_styles():
class UiPromptStyles:
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
self.tabname = tabname
+ self.main_ui_prompt = main_ui_prompt
+ self.main_ui_negative_prompt = main_ui_negative_prompt
with gr.Row(elem_id=f"{tabname}_styles_row"):
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
@@ -61,7 +64,8 @@ class UiPromptStyles:
with gr.Row():
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
- self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+ self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+ self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
with gr.Row():
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
@@ -96,15 +100,21 @@ class UiPromptStyles:
show_progress=False,
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
- self.materialize.click(
- fn=materialize_styles,
- inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
- outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ self.setup_apply_button(self.materialize)
+
+ self.copy.click(
+ fn=lambda p, n: (p, n),
+ inputs=[main_ui_prompt, main_ui_negative_prompt],
+ outputs=[self.prompt, self.neg_prompt],
show_progress=False,
- ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+ )
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
-
-
-
+ def setup_apply_button(self, button):
+ button.click(
+ fn=materialize_styles,
+ inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
+ outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
+ show_progress=False,
+ ).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False)
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 8ff9c074..e054d00a 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -1,10 +1,11 @@
import gradio as gr
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript
+from concurrent.futures import ThreadPoolExecutor, as_completed
def get_value_for_setting(key):
@@ -63,6 +64,9 @@ class UiSettings:
quicksettings_list = None
quicksettings_names = None
text_settings = None
+ show_all_pages = None
+ show_one_page = None
+ search_input = None
def run_settings(self, *args):
changed = []
@@ -135,7 +139,7 @@ class UiSettings:
gr.Group()
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__()
- current_row = gr.Column(variant='compact')
+ current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact')
current_row.__enter__()
previous_section = item.section
@@ -173,26 +177,43 @@ class UiSettings:
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row():
- unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
- reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
+ with gr.Row():
+ calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
+ calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
- gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+ self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+ self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False)
+ self.show_one_page.click(lambda: None)
+
+ self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False)
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+ def call_func_and_return_text(func, text):
+ def handler():
+ t = timer.Timer()
+ func()
+ t.record(text)
+
+ return f'{text} in {t.total:.1f}s'
+
+ return handler
+
unload_sd_model.click(
- fn=sd_models.unload_model_weights,
+ fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
inputs=[],
- outputs=[]
+ outputs=[self.result]
)
reload_sd_model.click(
- fn=sd_models.reload_model_weights,
+ fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
inputs=[],
- outputs=[]
+ outputs=[self.result]
)
request_notifications.click(
@@ -241,6 +262,21 @@ class UiSettings:
outputs=[sysinfo_check_output],
)
+ def calculate_all_checkpoint_hash_fn(max_thread):
+ checkpoints_list = sd_models.checkpoints_list.values()
+ with ThreadPoolExecutor(max_workers=max_thread) as executor:
+ futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
+ completed = 0
+ for _ in as_completed(futures):
+ completed += 1
+ print(f"{completed} / {len(checkpoints_list)} ")
+ print("Finish calculating hash for all checkpoints")
+
+ calculate_all_checkpoint_hash.click(
+ fn=calculate_all_checkpoint_hash_fn,
+ inputs=[calculate_all_checkpoint_hash_threads],
+ )
+
self.interface = settings_interface
def add_quicksettings(self):
@@ -294,3 +330,8 @@ class UiSettings:
outputs=[self.component_dict[k] for k in component_keys],
queue=False,
)
+
+ def search(self, text):
+ print(text)
+
+ return [gr.update(visible=text in (comp.label or "")) for comp in self.components]
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
new file mode 100644
index 00000000..985b5a2d
--- /dev/null
+++ b/modules/ui_toprow.py
@@ -0,0 +1,141 @@
+import gradio as gr
+
+from modules import shared, ui_prompt_styles
+import modules.images
+
+from modules.ui_components import ToolButton
+
+
+class Toprow:
+ """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
+
+ prompt = None
+ prompt_img = None
+ negative_prompt = None
+
+ button_interrogate = None
+ button_deepbooru = None
+
+ interrupt = None
+ skip = None
+ submit = None
+
+ paste = None
+ clear_prompt_button = None
+ apply_styles = None
+ restore_progress_button = None
+
+ token_counter = None
+ token_button = None
+ negative_token_counter = None
+ negative_token_button = None
+
+ ui_styles = None
+
+ submit_box = None
+
+ def __init__(self, is_img2img, is_compact=False):
+ id_part = "img2img" if is_img2img else "txt2img"
+ self.id_part = id_part
+ self.is_img2img = is_img2img
+ self.is_compact = is_compact
+
+ if not is_compact:
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ self.create_classic_toprow()
+ else:
+ self.create_submit_box()
+
+ def create_classic_toprow(self):
+ self.create_prompts()
+
+ with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"):
+ self.create_submit_box()
+
+ self.create_tools_row()
+
+ self.create_styles_ui()
+
+ def create_inline_toprow_prompts(self):
+ if not self.is_compact:
+ return
+
+ self.create_prompts()
+
+ with gr.Row(elem_classes=["toprow-compact-stylerow"]):
+ with gr.Column(elem_classes=["toprow-compact-tools"]):
+ self.create_tools_row()
+ with gr.Column():
+ self.create_styles_ui()
+
+ def create_inline_toprow_image(self):
+ if not self.is_compact:
+ return
+
+ self.submit_box.render()
+
+ def create_prompts(self):
+ with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
+ with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
+
+ with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+
+ self.prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[self.prompt_img],
+ outputs=[self.prompt, self.prompt_img],
+ show_progress=False,
+ )
+
+ def create_submit_box(self):
+ with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
+ self.submit_box = submit_box
+
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
+ self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
+ self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
+
+ self.skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ self.interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ def create_tools_row(self):
+ with gr.Row(elem_id=f"{self.id_part}_tools"):
+ from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol
+
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt")
+ self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
+
+ if self.is_img2img:
+ self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate")
+ self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru")
+
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
+
+ self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
+ self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
+ self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
+ self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
+
+ self.clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[self.prompt, self.negative_prompt],
+ outputs=[self.prompt, self.negative_prompt],
+ )
+
+ def create_styles_ui(self):
+ self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
+ self.ui_styles.setup_apply_button(self.apply_styles)
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
new file mode 100644
index 00000000..a727e865
--- /dev/null
+++ b/modules/xlmr_m18.py
@@ -0,0 +1,164 @@
+from transformers import BertPreTrainedModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob= 0.1
+ config.bos_token_id=0
+ config.eos_token_id=2
+ config.hidden_act='gelu'
+ config.hidden_dropout_prob=0.1
+ config.hidden_size=1024
+ config.initializer_range=0.02
+ config.intermediate_size=4096
+ config.layer_norm_eps=1e-05
+ config.max_position_embeddings=514
+
+ config.num_attention_heads=16
+ config.num_hidden_layers=24
+ config.output_past=True
+ config.pad_token_id=1
+ config.position_embedding_type= "absolute"
+
+ config.type_vocab_size= 1
+ config.use_cache=True
+ config.vocab_size= 250002
+ config.project_dim = 1024
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+ # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ # self.pooler = lambda x: x[:,0]
+ # self.post_init()
+
+ self.has_pre_transformation = True
+ if self.has_pre_transformation:
+ self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+ self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.post_init()
+
+ def encode(self,c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+ text["attention_mask"] = torch.tensor(
+ text['attention_mask']).to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) :
+ r"""
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ # # last module outputs
+ # sequence_output = outputs[0]
+
+
+ # # project every module
+ # sequence_output_ln = self.pre_LN(sequence_output)
+
+ # # pooler
+ # pooler_output = self.pooler(sequence_output_ln)
+ # pooler_output = self.transformation(pooler_output)
+ # projection_state = self.transformation(outputs.last_hidden_state)
+
+ if self.has_pre_transformation:
+ sequence_output2 = outputs["hidden_states"][-2]
+ sequence_output2 = self.pre_LN(sequence_output2)
+ projection_state2 = self.transformation_pre(sequence_output2)
+
+ return {
+ "projection_state": projection_state2,
+ "last_hidden_state": outputs.last_hidden_state,
+ "hidden_states": outputs.hidden_states,
+ "attentions": outputs.attentions,
+ }
+ else:
+ projection_state = self.transformation(outputs.last_hidden_state)
+ return {
+ "projection_state": projection_state,
+ "last_hidden_state": outputs.last_hidden_state,
+ "hidden_states": outputs.hidden_states,
+ "attentions": outputs.attentions,
+ }
+
+
+ # return {
+ # 'pooler_output':pooler_output,
+ # 'last_hidden_state':outputs.last_hidden_state,
+ # 'hidden_states':outputs.hidden_states,
+ # 'attentions':outputs.attentions,
+ # 'projection_state':projection_state,
+ # 'sequence_out': sequence_output
+ # }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
+ config_class= RobertaSeriesConfig
diff --git a/requirements_versions.txt b/requirements_versions.txt
index e84bd427..cb7403a9 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -27,6 +27,6 @@ timm==0.9.2
tomesd==0.1.3
torch
torchdiffeq==0.2.3
-torchsde==0.2.5
+torchsde==0.2.6
transformers==4.30.2
httpx==0.24.1
diff --git a/script.js b/script.js
index 34cca765..c0e678ea 100644
--- a/script.js
+++ b/script.js
@@ -124,16 +124,29 @@ document.addEventListener("DOMContentLoaded", function() {
* Add a ctrl+enter as a shortcut to start a generation
*/
document.addEventListener('keydown', function(e) {
- var handled = false;
- if (e.key !== undefined) {
- if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
- } else if (e.keyCode !== undefined) {
- if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
- }
- if (handled) {
- var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
- if (button) {
- button.click();
+ const isEnter = e.key === 'Enter' || e.keyCode === 13;
+ const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
+
+ const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
+ const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
+
+ if (isEnter && isModifierKey) {
+ if (interruptButton.style.display === 'block') {
+ interruptButton.click();
+ const callback = (mutationList) => {
+ for (const mutation of mutationList) {
+ if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
+ if (interruptButton.style.display === 'none') {
+ generateButton.click();
+ observer.disconnect();
+ }
+ }
+ }
+ };
+ const observer = new MutationObserver(callback);
+ observer.observe(interruptButton, {attributes: true});
+ } else {
+ generateButton.click();
}
e.preventDefault();
}
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index edb70ac0..eb42a29e 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -29,7 +29,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
+ upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow():
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 50320d55..a4a2f24d 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -5,11 +5,17 @@ import shlex
import modules.scripts as scripts
import gradio as gr
-from modules import sd_samplers, errors
+from modules import sd_samplers, errors, sd_models
from modules.processing import Processed, process_images
from modules.shared import state
+def process_model_tag(tag):
+ info = sd_models.get_closet_checkpoint_match(tag)
+ assert info is not None, f'Unknown checkpoint: {tag}'
+ return info.name
+
+
def process_string_tag(tag):
return tag
@@ -27,7 +33,7 @@ def process_boolean_tag(tag):
prompt_tags = {
- "sd_model": None,
+ "sd_model": process_model_tag,
"outpath_samples": process_string_tag,
"outpath_grids": process_string_tag,
"prompt_for_display": process_string_tag,
@@ -108,6 +114,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
+ prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
@@ -118,9 +125,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
- return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
+ return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
- def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
+ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True
@@ -156,7 +163,22 @@ class Script(scripts.Script):
copy_p = copy.copy(p)
for k, v in args.items():
- setattr(copy_p, k, v)
+ if k == "sd_model":
+ copy_p.override_settings['sd_model_checkpoint'] = v
+ else:
+ setattr(copy_p, k, v)
+
+ if args.get("prompt") and p.prompt:
+ if prompt_position == "start":
+ copy_p.prompt = args.get("prompt") + " " + p.prompt
+ else:
+ copy_p.prompt = p.prompt + " " + args.get("prompt")
+
+ if args.get("negative_prompt") and p.negative_prompt:
+ if prompt_position == "start":
+ copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
+ else:
+ copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")
proc = process_images(copy_p)
images += proc.images
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 939d8605..0dc255bc 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -205,13 +205,14 @@ def csv_string_to_list_strip(data_str):
class AxisOption:
- def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
+ def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
self.label = label
self.type = type
self.apply = apply
self.format_value = format_value
self.confirm = confirm
self.cost = cost
+ self.prepare = prepare
self.choices = choices
@@ -536,6 +537,8 @@ class Script(scripts.Script):
if opt.choices is not None and not csv_mode:
valslist = vals_dropdown
+ elif opt.prepare is not None:
+ valslist = opt.prepare(vals)
else:
valslist = csv_string_to_list_strip(vals)
@@ -773,6 +776,8 @@ class Script(scripts.Script):
# TODO: See previous comment about intentional data misalignment.
adj_g = g-1 if g > 0 else g
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
+ if not include_sub_grids: # if not include_sub_grids then skip saving after the first grid
+ break
if not include_sub_grids:
# Done with sub-grids, drop all related information:
diff --git a/style.css b/style.css
index fb4e2f1f..73162022 100644
--- a/style.css
+++ b/style.css
@@ -83,8 +83,10 @@ div.compact{
white-space: nowrap;
}
-.gradio-dropdown ul.options li.item {
- padding: 0.05em 0;
+@media (pointer:fine) {
+ .gradio-dropdown ul.options li.item {
+ padding: 0.05em 0;
+ }
}
.gradio-dropdown ul.options li.item.selected {
@@ -202,6 +204,11 @@ div.block.gradio-accordion {
padding: 8px 8px;
}
+input[type="checkbox"].input-accordion-checkbox{
+ vertical-align: sub;
+ margin-right: 0.5em;
+}
+
/* txt2img/img2img specific */
@@ -289,6 +296,13 @@ div.block.gradio-accordion {
min-height: 4.5em;
}
+#txt2img_generate, #img2img_generate {
+ min-height: 4.5em;
+}
+.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate {
+ min-height: 3em;
+}
+
@media screen and (min-width: 2500px) {
#txt2img_gallery, #img2img_gallery {
min-height: 768px;
@@ -396,6 +410,15 @@ div#extras_scale_to_tab div.form{
min-width: 0.5em;
}
+div.toprow-compact-stylerow{
+ margin: 0.5em 0;
+}
+
+div.toprow-compact-tools{
+ min-width: fit-content !important;
+ max-width: fit-content;
+}
+
/* settings */
#quicksettings {
align-items: end;
@@ -421,6 +444,7 @@ div#extras_scale_to_tab div.form{
#settings > div{
border: none;
margin-left: 10em;
+ padding: 0 var(--spacing-xl);
}
#settings > div.tab-nav{
@@ -435,6 +459,7 @@ div#extras_scale_to_tab div.form{
border: none;
text-align: left;
white-space: initial;
+ padding: 4px;
}
#settings_result{
@@ -516,7 +541,8 @@ table.popup-table .link{
height: 20px;
background: #b4c0cc;
border-radius: 3px !important;
- top: -20px;
+ top: -14px;
+ left: 0px;
width: 100%;
}
@@ -581,7 +607,6 @@ table.popup-table .link{
width: 100%;
height: 100%;
overflow: auto;
- background-color: rgba(20, 20, 20, 0.95);
}
.global-popup *{
@@ -590,9 +615,6 @@ table.popup-table .link{
.global-popup-close:before {
content: "×";
-}
-
-.global-popup-close{
position: fixed;
right: 0.25em;
top: 0;
@@ -601,10 +623,20 @@ table.popup-table .link{
font-size: 32pt;
}
+.global-popup-close{
+ position: fixed;
+ left: 0;
+ top: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(20, 20, 20, 0.95);
+}
+
.global-popup-inner{
display: inline-block;
margin: auto;
padding: 2em;
+ z-index: 1001;
}
/* fullpage image viewer */
@@ -808,6 +840,10 @@ footer {
/* extra networks UI */
+.extra-page .prompt{
+ margin: 0 0 0.5em 0;
+}
+
.extra-network-cards{
height: calc(100vh - 24rem);
overflow: clip scroll;
diff --git a/webui.bat b/webui.bat
index 42e7d517..e2c9079d 100644
--- a/webui.bat
+++ b/webui.bat
@@ -1,6 +1,11 @@
@echo off
+if exist webui.settings.bat (
+ call webui.settings.bat
+)
+
if not defined PYTHON (set PYTHON=python)
+if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
set SD_WEBUI_RESTART=tmp/restart
diff --git a/webui.py b/webui.py
index 12328423..9ed20b30 100644
--- a/webui.py
+++ b/webui.py
@@ -74,7 +74,7 @@ def webui():
if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
auto_launch_browser = True
elif shared.opts.auto_launch_browser == "Local":
- auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok, cmd_opts.server_name])
+ auto_launch_browser = not cmd_opts.webui_is_non_local
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
diff --git a/webui.sh b/webui.sh
index 3d0f87ee..08911469 100755
--- a/webui.sh
+++ b/webui.sh
@@ -4,12 +4,6 @@
# change the variables in webui-user.sh instead #
#################################################
-
-use_venv=1
-if [[ $venv_dir == "-" ]]; then
- use_venv=0
-fi
-
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
@@ -28,6 +22,12 @@ then
source "$SCRIPT_DIR"/webui-user.sh
fi
+# If $venv_dir is "-", then disable venv support
+use_venv=1
+if [[ $venv_dir == "-" ]]; then
+ use_venv=0
+fi
+
# Set defaults
# Install directory without trailing slash
if [[ -z "${install_dir}" ]]
@@ -51,6 +51,8 @@ fi
if [[ -z "${GIT}" ]]
then
export GIT="git"
+else
+ export GIT_PYTHON_GIT_EXECUTABLE="${GIT}"
fi
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
@@ -141,9 +143,8 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
- export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
- # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5)
- # so switch to nightly rocm5.6 without explicit versions this time
+ export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
+ # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"