aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/on_pull_request.yaml2
-rw-r--r--.github/workflows/run_tests.yaml4
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py8
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py20
-rw-r--r--extensions-builtin/Lora/lora.py2
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py30
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py62
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js30
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py1
-rw-r--r--javascript/edit-attention.js5
-rw-r--r--javascript/edit-order.js41
-rw-r--r--javascript/extensions.js18
-rw-r--r--modules/api/api.py128
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/call_queue.py5
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/codeformer_model.py6
-rw-r--r--modules/devices.py10
-rw-r--r--modules/esrgan_model.py23
-rw-r--r--modules/extra_networks.py3
-rw-r--r--modules/extras.py3
-rw-r--r--modules/generation_parameters_copypaste.py29
-rw-r--r--modules/gfpgan_model.py2
-rw-r--r--modules/hypernetworks/hypernetwork.py24
-rw-r--r--modules/images.py59
-rw-r--r--modules/img2img.py47
-rw-r--r--modules/interrogate.py3
-rw-r--r--modules/launch_utils.py6
-rw-r--r--modules/mac_specific.py23
-rw-r--r--modules/modelloader.py31
-rw-r--r--modules/paths.py14
-rw-r--r--modules/postprocessing.py7
-rw-r--r--modules/processing.py21
-rw-r--r--modules/realesrgan_model.py33
-rw-r--r--modules/scripts.py40
-rw-r--r--modules/sd_models.py8
-rw-r--r--modules/shared.py35
-rw-r--r--modules/textual_inversion/logging.py48
-rw-r--r--modules/textual_inversion/preprocess.py2
-rw-r--r--modules/txt2img.py6
-rw-r--r--modules/ui.py13
-rw-r--r--modules/ui_extensions.py29
-rw-r--r--modules/ui_extra_networks.py8
-rw-r--r--style.css17
-rw-r--r--webui.py34
-rwxr-xr-xwebui.sh12
46 files changed, 606 insertions, 352 deletions
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml
index 7b7219fd..8ebf5918 100644
--- a/.github/workflows/on_pull_request.yaml
+++ b/.github/workflows/on_pull_request.yaml
@@ -18,7 +18,7 @@ jobs:
# not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies.
- name: Install Ruff
- run: pip install ruff==0.0.265
+ run: pip install ruff==0.0.272
- name: Run Ruff
run: ruff .
lint-js:
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 226cf759..178c026a 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -42,7 +42,7 @@ jobs:
--no-half
--disable-opt-split-attention
--use-cpu all
- --add-stop-route
+ --api-server-stop
2>&1 | tee output.txt &
- name: Run tests
run: |
@@ -50,7 +50,7 @@ jobs:
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server
if: always()
- run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
+ run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
- name: Show coverage
run: |
python -m coverage combine .coverage*
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 7f450086..7cac36ce 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -12,7 +12,7 @@ import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
-from modules import shared, sd_hijack
+from modules import shared, sd_hijack, devices
cached_ldsr_model: torch.nn.Module = None
@@ -112,8 +112,7 @@ class LDSR:
gc.collect()
- if torch.cuda.is_available:
- torch.cuda.empty_cache()
+ devices.torch_gc()
im_og = image
width_og, height_og = im_og.size
@@ -150,8 +149,7 @@ class LDSR:
del model
gc.collect()
- if torch.cuda.is_available:
- torch.cuda.empty_cache()
+ devices.torch_gc()
return a
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index dbd6d331..bd78dece 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,7 +1,6 @@
import os
-from basicsr.utils.download_util import load_file_from_url
-
+from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
@@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
- model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
+ model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
- yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
+ yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
- try:
- return LDSR(model, yaml)
- except Exception:
- errors.report("Error importing LDSR", exc_info=True)
- return None
+ return LDSR(model, yaml)
def do_upscale(self, img, path):
- ldsr = self.load_model(path)
- if ldsr is None:
- print("NO LDSR!")
+ try:
+ ldsr = self.load_model(path)
+ except Exception:
+ errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 34ff57dd..cd46e6c7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -443,7 +443,7 @@ def list_available_loras():
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
- for filename in sorted(candidates, key=str.lower):
+ for filename in candidates:
if os.path.isdir(filename):
continue
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 85b4505f..167d2f64 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,4 +1,3 @@
-import os.path
import sys
import PIL.Image
@@ -6,12 +5,11 @@ import numpy as np
import torch
from tqdm import tqdm
-from basicsr.utils.download_util import load_file_from_url
-
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet
+from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = []
add_model2 = True
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -87,11 +85,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def do_upscale(self, img: PIL.Image.Image, selected_file):
- torch.cuda.empty_cache()
+ devices.torch_gc()
- model = self.load_model(selected_file)
- if model is None:
- print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+ try:
+ model = self.load_model(selected_file)
+ except Exception as e:
+ print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
@@ -111,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
- torch.cuda.empty_cache()
+ devices.torch_gc()
output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB
@@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
- print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
- return None
-
- model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+ model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 1c7bf325..c2c2a43c 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,17 +1,17 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir')
@@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
self.name = "SwinIR"
- self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
- "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
- "-L_x4_GAN.pth "
+ self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
- if "http" in model:
+ if model.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(model)
@@ -37,42 +35,42 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img, model_file):
- model = self.load_model(model_file)
- if model is None:
+ try:
+ model = self.load_model(model_file)
+ except Exception as e:
+ print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
- try:
- torch.cuda.empty_cache()
- except Exception:
- pass
+ devices.torch_gc()
return img
def load_model(self, path, scale=4):
- if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+ if path.startswith("http"):
+ filename = modelloader.load_file_from_url(
+ url=path,
+ model_dir=self.model_download_path,
+ file_name=f"{self.model_name.replace(' ', '_')}.pth",
+ )
else:
filename = path
- if filename is None or not os.path.exists(filename):
- return None
if filename.endswith(".v2.pth"):
- model = net2(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
+ model = Swin2SR(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="1conv",
)
params = None
else:
- model = net(
+ model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
index 5ebd2073..30199dcd 100644
--- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
+++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
@@ -200,7 +200,8 @@ onUiLoaded(async() => {
canvas_hotkey_move: "KeyF",
canvas_hotkey_overlap: "KeyO",
canvas_disabled_functions: [],
- canvas_show_tooltip: true
+ canvas_show_tooltip: true,
+ canvas_blur_prompt: false
};
const functionMap = {
@@ -608,6 +609,19 @@ onUiLoaded(async() => {
// Handle keydown events
function handleKeyDown(event) {
+ // Disable key locks to make pasting from the buffer work correctly
+ if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
+ return;
+ }
+
+ // before activating shortcut, ensure user is not actively typing in an input field
+ if (!hotkeysConfig.canvas_blur_prompt) {
+ if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
+ return;
+ }
+ }
+
+
const hotkeyActions = {
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
@@ -686,6 +700,20 @@ onUiLoaded(async() => {
// Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
function handleMoveKeyDown(e) {
+
+ // Disable key locks to make pasting from the buffer work correctly
+ if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
+ return;
+ }
+
+ // before activating shortcut, ensure user is not actively typing in an input field
+ if (!hotkeysConfig.canvas_blur_prompt) {
+ if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
+ return;
+ }
+ }
+
+
if (e.code === hotkeysConfig.canvas_hotkey_move) {
if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
e.preventDefault();
diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
index 1b6683aa..380176ce 100644
--- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
+++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
@@ -9,5 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
+ "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index ffa73147..8906c892 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -100,11 +100,12 @@ function keyupEditAttention(event) {
if (String(weight).length == 1) weight += ".0";
if (closeCharacter == ')' && weight == 1) {
- text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+ var endParenPos = text.substring(selectionEnd).indexOf(')');
+ text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
selectionStart--;
selectionEnd--;
} else {
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
}
target.focus();
diff --git a/javascript/edit-order.js b/javascript/edit-order.js
new file mode 100644
index 00000000..ad983d33
--- /dev/null
+++ b/javascript/edit-order.js
@@ -0,0 +1,41 @@
+/* alt+left/right moves text in prompt */
+
+function keyupEditOrder(event) {
+ if (!opts.keyedit_move) return;
+
+ let target = event.originalTarget || event.composedPath()[0];
+ if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
+ if (!event.altKey) return;
+ event.preventDefault()
+
+ let isLeft = event.key == "ArrowLeft";
+ let isRight = event.key == "ArrowRight";
+ if (!isLeft && !isRight) return;
+
+ let selectionStart = target.selectionStart;
+ let selectionEnd = target.selectionEnd;
+ let text = target.value;
+ let items = text.split(",");
+ let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
+ let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
+ let range = indexEnd - indexStart + 1;
+
+ if (isLeft && indexStart > 0) {
+ items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
+ target.value = items.join();
+ target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
+ target.selectionEnd = items.slice(0, indexEnd).join().length;
+ } else if (isRight && indexEnd < items.length - 1) {
+ items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
+ target.value = items.join();
+ target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
+ target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
+ }
+
+ event.preventDefault();
+ updateInput(target);
+}
+
+addEventListener('keydown', (event) => {
+ keyupEditOrder(event);
+});
diff --git a/javascript/extensions.js b/javascript/extensions.js
index efeaf3a5..1f7254c5 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
}
return [confirmed, config_state_name, config_restore_type];
}
+
+function toggle_all_extensions(event) {
+ gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
+ checkbox_el.checked = event.target.checked;
+ });
+}
+
+function toggle_extension() {
+ let all_extensions_toggled = true;
+ for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
+ if (!checkbox_el.checked) {
+ all_extensions_toggled = false;
+ break;
+ }
+ }
+
+ gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
+}
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..224bbfc6 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases
from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
@@ -30,13 +30,7 @@ from modules import devices
from typing import Dict, List, Any
import piexif
import piexif.helper
-
-
-def upscaler_to_index(name: str):
- try:
- return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+from contextlib import closing
def script_name_to_index(name, scripts):
@@ -84,6 +78,8 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -209,6 +205,11 @@ class Api:
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+ if shared.cmd_opts.api_server_stop:
+ self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -324,19 +325,19 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
-
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+
+ shared.state.begin(job="scripts_txt2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -380,20 +381,20 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- p.init_images = [decode_base64_to_image(x) for x in init_images]
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
-
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+
+ shared.state.begin(job="scripts_img2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -517,6 +518,10 @@ class Api:
return options
def set_config(self, req: Dict[str, Any]):
+ checkpoint_name = req.get("sd_model_checkpoint", None)
+ if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
+ raise RuntimeError(f"model {checkpoint_name!r} not found")
+
for k, v in req.items():
shared.opts.set(k, v)
@@ -597,44 +602,42 @@ class Api:
def create_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
- shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}")
+ finally:
+ shared.state.end()
+
def create_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
- shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}")
+ finally:
+ shared.state.end()
def preprocess(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
- return models.PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
- shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
- except AssertionError as e:
- shared.state.end()
+ except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
- except FileNotFoundError as e:
+ finally:
shared.state.end()
- return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -647,15 +650,15 @@ class Api:
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
- shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
- shared.state.end()
+ except Exception as msg:
return models.TrainResponse(info=f"train embedding error: {msg}")
+ finally:
+ shared.state.end()
def train_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
@@ -673,9 +676,10 @@ class Api:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError:
+ except Exception as exc:
+ return models.TrainResponse(info=f"train embedding error: {exc}")
+ finally:
shared.state.end()
- return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
@@ -715,3 +719,15 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+ def kill_webui(self):
+ restart.stop_program()
+
+ def restart_webui(self):
+ if restart.is_restartable():
+ restart.restart_program()
+ return Response(status_code=501)
+
+ def stop_webui(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
diff --git a/modules/api/models.py b/modules/api/models.py
index b3a745f0..b5683071 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
-class ArtistItem(BaseModel):
- name: str = Field(title="Name")
- score: float = Field(title="Score")
- category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 1b5e5273..3b94f8a4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,3 +1,4 @@
+from functools import wraps
import html
import threading
import time
@@ -18,6 +19,7 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
+ @wraps(func)
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
@@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ @wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index de905caa..278a605e 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index d974e4b8..da42b5e9 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -15,7 +15,6 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-have_codeformer = False
codeformer = None
@@ -100,7 +99,7 @@ def setup_model(dirname):
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
- torch.cuda.empty_cache()
+ devices.torch_gc()
except Exception:
errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -123,9 +122,6 @@ def setup_model(dirname):
return restored_img
- global have_codeformer
- have_codeformer = True
-
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
diff --git a/modules/devices.py b/modules/devices.py
index 1ed6ffdc..c5ad950f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -15,13 +15,6 @@ def has_mps() -> bool:
else:
return mac_specific.has_mps
-def extract_device_id(args, name):
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
-
- return None
-
def get_cuda_device_string():
from modules import shared
@@ -56,10 +49,13 @@ def get_device_for(task):
def torch_gc():
+
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
+ elif has_mps() and hasattr(torch.mps, 'empty_cache'):
+ torch.mps.empty_cache()
def enable_tf32():
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,15 +1,13 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 1f093df2..41799b0a 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -103,6 +103,9 @@ def activate(p, extra_network_data):
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
+ if p.scripts is not None:
+ p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
+
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
diff --git a/modules/extras.py b/modules/extras.py
index 830b53aa..e9c0263e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index dd30a1b5..a3448be9 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-settings_map = {}
-
-
-
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 6ecd295c..8e0f13bd 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -25,7 +25,7 @@ def gfpgann():
return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5d12b449..51941c11 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork)
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@@ -446,18 +435,6 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False
pbar.close()
hypernetwork.eval()
- #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
diff --git a/modules/images.py b/modules/images.py
index 7bbfc3e0..b5412548 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import pytz
@@ -10,7 +12,7 @@ import re
import numpy as np
import piexif
import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string
import json
import hashlib
@@ -139,6 +141,11 @@ class GridAnnotation:
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
fnt = get_font(fontsize)
- color_active = (0, 0, 0)
- color_inactive = (153, 153, 153)
-
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
cols = im.width // width
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
- calc_img = Image.new("RGB", (1, 1), "white")
+ calc_img = Image.new("RGB", (1, 1), color_background)
calc_d = ImageDraw.Draw(calc_img)
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
for row in range(rows):
for col in range(cols):
@@ -372,8 +376,8 @@ class FilenameGenerator:
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
-
}
default_time_format = '%Y%m%d%H%M%S'
@@ -497,13 +501,23 @@ def get_next_sequence_number(path, basename):
return result + 1
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+ """
+ Saves image to filename, including geninfo as text information for generation info.
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+ """
+
if extension is None:
extension = os.path.splitext(filename)[1]
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo[pnginfo_section_name] = geninfo
+
if opts.enable_pnginfo:
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in (existing_pnginfo or {}).items():
@@ -622,7 +636,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
temp_file_path = f"{filename_without_extension}.tmp"
- save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension)
@@ -639,12 +653,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
-
+ resize_to = None
if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
+ if resize_to is not None:
+ try:
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+ image = image.resize(resize_to, LANCZOS)
+ except Exception:
+ image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
@@ -662,8 +682,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn
-def read_info_from_image(image):
- items = image.info or {}
+IGNORED_INFO_KEYS = {
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+ items = (image.info or {}).copy()
geninfo = items.pop('parameters', None)
@@ -679,9 +706,7 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
- 'icc_profile', 'chromaticity']:
+ for field in IGNORED_INFO_KEYS:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
diff --git a/modules/img2img.py b/modules/img2img.py
index 2c497020..ef87eb0f 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,9 +3,10 @@ from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -14,10 +15,10 @@ from modules.ui import plaintext_to_html
import modules.scripts
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p)
- images = shared.listfiles(input_dir)
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
is_inpaint_batch = False
if inpaint_mask_dir:
@@ -36,6 +37,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
state.job_count = len(images) * p.n_iter
+ # extract "default" params to use in case getting png info fails
+ prompt = p.prompt
+ negative_prompt = p.negative_prompt
+ seed = p.seed
+ cfg_scale = p.cfg_scale
+ sampler_name = p.sampler_name
+ steps = p.steps
+
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
@@ -79,25 +88,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
+ if use_png_info:
+ try:
+ info_img = img
+ if png_info_dir:
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+ info_img = Image.open(info_img_path)
+ geninfo, _ = imgutil.read_info_from_image(info_img)
+ parsed_parameters = parse_generation_parameters(geninfo)
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+ except Exception:
+ parsed_parameters = {}
+
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+ p.seed = int(parsed_parameters.get("Seed", seed))
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+ p.steps = int(parsed_parameters.get("Steps", steps))
+
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
filename = image_path.name
+ relpath = os.path.dirname(os.path.relpath(image, input_dir))
if n > 0:
left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}"
if not save_normally:
- os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
if processed_image.mode == 'RGBA':
processed_image = processed_image.convert("RGB")
- processed_image.save(os.path.join(output_dir, filename))
+ processed_image.save(os.path.join(output_dir, relpath, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -180,6 +209,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img
p.script_args = args
+ p.user = request.username
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
@@ -189,7 +220,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
processed = Processed(p, [], p.seed, "")
else:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9b2c5b60..a3ae1dd5 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -184,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 609a181e..0e0dbca4 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -142,15 +142,15 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None:
return
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index d74c6b95..735847f5 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc
from packaging import version
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
- return False
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+ else:
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps()
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 75f01247..098bcb79 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str):
- if "http" in file:
+ if file.startswith("http"):
file = urlparse(file).path
file = os.path.basename(file)
diff --git a/modules/paths.py b/modules/paths.py
index 5171df4f..bada804e 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 736315e2..136e9c88 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
for image, name in zip(image_data, image_names):
shared.state.textinfo = name
- existing_pnginfo = image.info or {}
+ parameters, existing_pnginfo = images.read_info_from_image(image)
+ if parameters:
+ existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
diff --git a/modules/processing.py b/modules/processing.py
index 8da73884..21d1492c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -184,6 +184,8 @@ class StableDiffusionProcessing:
self.uc = None
self.c = None
+ self.user = None
+
@property
def sd_model(self):
return shared.sd_model
@@ -549,7 +551,7 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -573,7 +575,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
+ "User": p.user if opts.add_user_name_to_info else None,
}
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- text = infotext()
+ text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data)
@@ -1074,6 +1078,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+ if self.scripts is not None:
+ self.scripts.before_hr(self)
+
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 2d27b321..0700b853 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -2,7 +2,6 @@ import os
import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
- return info
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
diff --git a/modules/scripts.py b/modules/scripts.py
index 99bf836a..7d9dd59f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,6 +1,7 @@
import os
import re
import sys
+import inspect
from collections import namedtuple
import gradio as gr
@@ -116,6 +117,21 @@ class Script:
pass
+ def after_extra_networks_activate(self, p, *args, **kwargs):
+ """
+ Calledafter extra networks activation, before conds calculation
+ allow modification of the network after extra networks activation been applied
+ won't be call if p.disable_extra_networks
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ - extra_network_data - list of ExtraNetworkParams for current stage
+ """
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -186,6 +202,11 @@ class Script:
return f'script_{tabname}{title}_{item_id}'
+ def before_hr(self, p, *args):
+ """
+ This function is called before hires fix start.
+ """
+ pass
current_basedir = paths.script_path
@@ -249,7 +270,7 @@ def load_scripts():
def register_scripts_from_module(module):
for script_class in module.__dict__.values():
- if type(script_class) != type:
+ if not inspect.isclass(script_class):
continue
if issubclass(script_class, Script):
@@ -483,6 +504,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
+ def after_extra_networks_activate(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.after_extra_networks_activate(p, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
+
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
@@ -548,6 +577,15 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
+ def before_hr(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_hr(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ff5d17d..653c4cc0 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -247,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -585,7 +590,6 @@ def unload_model_weights(sd_model=None, info=None):
sd_model = None
gc.collect()
devices.torch_gc()
- torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.")
diff --git a/modules/shared.py b/modules/shared.py
index a0862055..48478a68 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,9 +1,11 @@
import datetime
import json
import os
+import re
import sys
import threading
import time
+import logging
import gradio as gr
import torch
@@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional
+log = logging.getLogger(__name__)
+
demo = None
parser = cmd_args.parser
@@ -144,12 +148,15 @@ class State:
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
+ log.info("Received restart request")
def skip(self):
self.skipped = True
+ log.info("Received skip request")
def interrupt(self):
self.interrupted = True
+ log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -173,7 +180,7 @@ class State:
return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -187,10 +194,13 @@ class State:
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
+ log.info("Starting job %s", job)
def end(self):
+ duration = time.time() - self.time_start
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
@@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "font": OptionInfo("", "Font for image grids that have text"),
+ "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@@ -376,6 +390,7 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -470,7 +485,6 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
@@ -481,6 +495,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
@@ -493,6 +508,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
@@ -817,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
def listfiles(dirname):
- filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
@@ -843,8 +863,11 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
- for root, _, files in os.walk(path, followlinks=True):
- for filename in files:
+ items = list(os.walk(path, followlinks=True))
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+ for root, _, files in items:
+ for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 734a4b6f..45823eb1 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,11 +2,51 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+ "batch_size",
+ "clip_grad_mode",
+ "clip_grad_value",
+ "create_image_every",
+ "data_root",
+ "gradient_step",
+ "initial_step",
+ "latent_sampling_method",
+ "learn_rate",
+ "log_directory",
+ "model_hash",
+ "model_name",
+ "num_of_dataset_images",
+ "steps",
+ "template_file",
+ "training_height",
+ "training_width",
+}
+saved_params_ti = {
+ "embedding_name",
+ "num_vectors_per_token",
+ "save_embedding_every",
+ "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+ "activation_func",
+ "add_layer_norm",
+ "hypernetwork_name",
+ "layer_structure",
+ "save_hypernetwork_every",
+ "use_dropout",
+ "weight_init",
+}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+ "preview_cfg_scale",
+ "preview_height",
+ "preview_negative_prompt",
+ "preview_prompt",
+ "preview_sampler_index",
+ "preview_seed",
+ "preview_steps",
+ "preview_width",
+}
def save_settings_to_file(log_directory, all_params):
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 0d4c3f84..dbd856bd 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2e7d202d..6aa79f23 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.ui import plaintext_to_html
+import gradio as gr
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
+ p.user = request.username
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/ui.py b/modules/ui.py
index e2e3b6da..39d226ad 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
return [gr.update(), None]
@@ -733,6 +733,10 @@ def create_ui():
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Accordion("PNG info", open=False):
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
@@ -773,7 +777,7 @@ def create_ui():
selected_scale_tab = gr.State(value=0)
with gr.Tabs():
- with gr.Tab(label="Resize to") as tab_scale_to:
+ with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -782,7 +786,7 @@ def create_ui():
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
- with gr.Tab(label="Resize by") as tab_scale_by:
+ with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow():
@@ -934,6 +938,9 @@ def create_ui():
img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir,
override_settings,
+ img2img_batch_use_png_info,
+ img2img_batch_png_info_props,
+ img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
img2img_gallery,
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index c7e0a866..dff522ef 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -138,7 +138,10 @@ def extension_table():
<table id="extensions">
<thead>
<tr>
- <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>
+ <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+ <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+ </th>
<th>URL</th>
<th>Branch</th>
<th>Version</th>
@@ -170,7 +173,7 @@ def extension_table():
code += f"""
<tr>
- <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td>
<td>{ext.branch}</td>
<td>{version_link}</td>
@@ -421,9 +424,19 @@ sort_ordering = [
(False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'),
+ (True, lambda x: x.get('commit_time', '')),
+ (True, lambda x: x.get('created_at', '')),
+ (True, lambda x: x.get('stars', 0)),
]
+def get_date(info: dict, key):
+ try:
+ return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+ except (ValueError, TypeError):
+ return ''
+
+
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
+ stars = int(ext.get("stars", 0))
added = ext.get('added', 'unknown')
+ update_time = get_date(ext, 'commit_time')
+ create_time = get_date(ext, 'created_at')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
- <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
+ <td>{html.escape(description)}<p class="info">
+ <span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
<td>{install_code}</td>
</tr>
@@ -559,7 +576,7 @@ def create_ui():
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
- sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)
@@ -568,9 +585,9 @@ def create_ui():
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column],
- outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
)
install_extension_button.click(
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index a7d3bc79..693cafb6 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -30,8 +30,8 @@ def fetch_file(filename: str = ""):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".jpeg", ".webp"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+ if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -90,8 +90,8 @@ class ExtraNetworksPage:
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, _ in os.walk(parentdir, followlinks=True):
- for dirname in dirs:
+ for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
+ for dirname in sorted(dirs, key=shared.natural_sort_key):
x = os.path.join(root, dirname)
if not os.path.isdir(x):
diff --git a/style.css b/style.css
index e1df716f..5073f0f0 100644
--- a/style.css
+++ b/style.css
@@ -704,11 +704,24 @@ table.popup-table .link{
margin: 0;
}
-#available_extensions .date_added{
- opacity: 0.85;
+#available_extensions .info{
+ margin: 0.5em 0;
+ display: flex;
+ margin-top: auto;
+ opacity: 0.80;
font-size: 90%;
}
+#available_extensions .date_added{
+ margin-right: auto;
+ display: inline-block;
+}
+
+#available_extensions .star_count{
+ margin-left: auto;
+ display: inline-block;
+}
+
/* replace original footer with ours */
footer {
diff --git a/webui.py b/webui.py
index 136d036d..34c2fd18 100644
--- a/webui.py
+++ b/webui.py
@@ -11,13 +11,24 @@ import json
from threading import Thread
from typing import Iterable
-from fastapi import FastAPI, Response
+from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
+# We can't use cmd_opts for this because it will not have been initialized at this point.
+log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
+if log_level:
+ log_level = getattr(logging, log_level.upper(), None) or logging.INFO
+ logging.basicConfig(
+ level=log_level,
+ format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ )
+
+logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import paths, timer, import_hook, errors, devices # noqa: F401
@@ -32,7 +43,7 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
startup_timer.record("import torch")
-import gradio
+import gradio # noqa: F401
startup_timer.record("import gradio")
import ldm.modules.encoders.modules # noqa: F401
@@ -359,12 +370,11 @@ def api_only():
modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
- api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-
-
-def stop_route(request):
- shared.state.server_command = "stop"
- return Response("Stopping.")
+ api.launch(
+ server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
+ port=cmd_opts.port if cmd_opts.port else 7861,
+ root_path = f"/{cmd_opts.subpath}"
+ )
def webui():
@@ -403,9 +413,8 @@ def webui():
"docs_url": "/docs",
"redoc_url": "/redoc",
},
+ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
)
- if cmd_opts.add_stop_route:
- app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -436,11 +445,6 @@ def webui():
timer.startup_record = startup_timer.dump()
print(f"Startup time: {startup_timer.summary()}.")
- if cmd_opts.subpath:
- redirector = FastAPI()
- redirector.get("/")
- gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
-
try:
while True:
server_command = shared.state.wait_for_server_command(timeout=5)
diff --git a/webui.sh b/webui.sh
index 5c8d977c..246381fc 100755
--- a/webui.sh
+++ b/webui.sh
@@ -4,26 +4,28 @@
# change the variables in webui-user.sh instead #
#################################################
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
# If run from macOS, load defaults from webui-macos-env.sh
if [[ "$OSTYPE" == "darwin"* ]]; then
- if [[ -f webui-macos-env.sh ]]
+ if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
then
- source ./webui-macos-env.sh
+ source "$SCRIPT_DIR"/webui-macos-env.sh
fi
fi
# Read variables from webui-user.sh
# shellcheck source=/dev/null
-if [[ -f webui-user.sh ]]
+if [[ -f "$SCRIPT_DIR"/webui-user.sh ]]
then
- source ./webui-user.sh
+ source "$SCRIPT_DIR"/webui-user.sh
fi
# Set defaults
# Install directory without trailing slash
if [[ -z "${install_dir}" ]]
then
- install_dir="$(pwd)"
+ install_dir="$SCRIPT_DIR"
fi
# Name of the subdirectory (defaults to stable-diffusion-webui)