aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py16
-rw-r--r--modules/config_states.py200
-rw-r--r--modules/devices.py8
-rw-r--r--modules/extensions.py39
-rw-r--r--modules/extra_networks_hypernet.py2
-rw-r--r--modules/extras.py46
-rw-r--r--modules/generation_parameters_copypaste.py6
-rw-r--r--modules/images.py31
-rw-r--r--modules/img2img.py9
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/ngrok.py12
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/postprocessing.py9
-rw-r--r--modules/processing.py37
-rw-r--r--modules/realesrgan_model.py14
-rw-r--r--modules/safe.py5
-rw-r--r--modules/sd_models.py11
-rw-r--r--modules/sd_samplers_common.py10
-rw-r--r--modules/sd_samplers_kdiffusion.py46
-rw-r--r--modules/shared.py48
-rw-r--r--modules/styles.py12
-rw-r--r--modules/textual_inversion/preprocess.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py6
-rw-r--r--modules/ui.py47
-rw-r--r--modules/ui_common.py2
-rw-r--r--modules/ui_components.py10
-rw-r--r--modules/ui_extensions.py245
-rw-r--r--modules/ui_extra_networks.py2
-rw-r--r--modules/ui_postprocessing.py8
29 files changed, 788 insertions, 102 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 518b2a61..9ffcbd5f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -6,7 +6,6 @@ import uvicorn
import gradio as gr
from threading import Lock
from io import BytesIO
-from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
@@ -272,7 +271,9 @@ class Api:
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]:
- script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+ # min between arg length in scriptrunner and arg length in the request
+ for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
+ script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
@@ -395,16 +396,11 @@ class Api:
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
- def prepareFiles(file):
- file = decode_base64_to_file(file.data, file_path=file.name)
- file.orig_name = file.name
- return file
-
- reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
- reqDict.pop('imageList')
+ image_list = reqDict.pop('imageList', [])
+ image_folder = [decode_base64_to_image(x.data) for x in image_list]
with self.queue_lock:
- result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
diff --git a/modules/config_states.py b/modules/config_states.py
new file mode 100644
index 00000000..2ea00929
--- /dev/null
+++ b/modules/config_states.py
@@ -0,0 +1,200 @@
+"""
+Supports saving and restoring webui and extensions from a known working set of commits
+"""
+
+import os
+import sys
+import traceback
+import json
+import time
+import tqdm
+
+from datetime import datetime
+from collections import OrderedDict
+import git
+
+from modules import shared, extensions
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+
+
+all_config_states = OrderedDict()
+
+
+def list_config_states():
+ global all_config_states
+
+ all_config_states.clear()
+ os.makedirs(config_states_dir, exist_ok=True)
+
+ config_states = []
+ for filename in os.listdir(config_states_dir):
+ if filename.endswith(".json"):
+ path = os.path.join(config_states_dir, filename)
+ with open(path, "r", encoding="utf-8") as f:
+ j = json.load(f)
+ j["filepath"] = path
+ config_states.append(j)
+
+ config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+
+ for cs in config_states:
+ timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ name = cs.get("name", "Config")
+ full_name = f"{name}: {timestamp}"
+ all_config_states[full_name] = cs
+
+ return all_config_states
+
+
+def get_webui_config():
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ webui_remote = None
+ webui_commit_hash = None
+ webui_commit_date = None
+ webui_branch = None
+ if webui_repo and not webui_repo.bare:
+ try:
+ webui_remote = next(webui_repo.remote().urls, None)
+ head = webui_repo.head.commit
+ webui_commit_date = webui_repo.head.commit.committed_date
+ webui_commit_hash = head.hexsha
+ webui_branch = webui_repo.active_branch.name
+
+ except Exception:
+ webui_remote = None
+
+ return {
+ "remote": webui_remote,
+ "commit_hash": webui_commit_hash,
+ "commit_date": webui_commit_date,
+ "branch": webui_branch,
+ }
+
+
+def get_extension_config():
+ ext_config = {}
+
+ for ext in extensions.extensions:
+ entry = {
+ "name": ext.name,
+ "path": ext.path,
+ "enabled": ext.enabled,
+ "is_builtin": ext.is_builtin,
+ "remote": ext.remote,
+ "commit_hash": ext.commit_hash,
+ "commit_date": ext.commit_date,
+ "branch": ext.branch,
+ "have_info_from_repo": ext.have_info_from_repo
+ }
+
+ ext_config[ext.name] = entry
+
+ return ext_config
+
+
+def get_config():
+ creation_time = datetime.now().timestamp()
+ webui_config = get_webui_config()
+ ext_config = get_extension_config()
+
+ return {
+ "created_at": creation_time,
+ "webui": webui_config,
+ "extensions": ext_config
+ }
+
+
+def restore_webui_config(config):
+ print("* Restoring webui state...")
+
+ if "webui" not in config:
+ print("Error: No webui data saved to config")
+ return
+
+ webui_config = config["webui"]
+
+ if "commit_hash" not in webui_config:
+ print("Error: No commit saved to webui config")
+ return
+
+ webui_commit_hash = webui_config.get("commit_hash", None)
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return
+
+ try:
+ webui_repo.git.fetch(all=True)
+ webui_repo.git.reset(webui_commit_hash, hard=True)
+ print(f"* Restored webui to commit {webui_commit_hash}.")
+ except Exception:
+ print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+def restore_extension_config(config):
+ print("* Restoring extension state...")
+
+ if "extensions" not in config:
+ print("Error: No extension data saved to config")
+ return
+
+ ext_config = config["extensions"]
+
+ results = []
+ disabled = []
+
+ for ext in tqdm.tqdm(extensions.extensions):
+ if ext.is_builtin:
+ continue
+
+ ext.read_info_from_repo()
+ current_commit = ext.commit_hash
+
+ if ext.name not in ext_config:
+ ext.disabled = True
+ disabled.append(ext.name)
+ results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
+ continue
+
+ entry = ext_config[ext.name]
+
+ if "commit_hash" in entry and entry["commit_hash"]:
+ try:
+ ext.fetch_and_reset_hard(entry["commit_hash"])
+ ext.read_info_from_repo()
+ if current_commit != entry["commit_hash"]:
+ results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
+ except Exception as ex:
+ results.append((ext, current_commit[:8], False, ex))
+ else:
+ results.append((ext, current_commit[:8], False, "No commit hash found in config"))
+
+ if not entry.get("enabled", False):
+ ext.disabled = True
+ disabled.append(ext.name)
+ else:
+ ext.disabled = False
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ print("* Finished restoring extensions. Results:")
+ for ext, prev_commit, success, result in results:
+ if success:
+ print(f" + {ext.name}: {prev_commit} -> {result}")
+ else:
+ print(f" ! {ext.name}: FAILURE ({result})")
diff --git a/modules/devices.py b/modules/devices.py
index 52c3e7cd..c705a3cb 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
+ from modules.shared import opts
+
torch.manual_seed(seed)
- if device.type == 'mps':
+ if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
- if device.type == 'mps':
+ from modules.shared import opts
+
+ if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
diff --git a/modules/extensions.py b/modules/extensions.py
index 3a7a0372..34d9d654 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -3,10 +3,11 @@ import sys
import traceback
import time
+from datetime import datetime
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
extensions = []
@@ -31,12 +32,15 @@ class Extension:
self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.commit_hash = ''
+ self.commit_date = None
self.version = ''
+ self.branch = None
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
- if self.have_info_from_repo:
+ if self.is_builtin or self.have_info_from_repo:
return
self.have_info_from_repo = True
@@ -56,10 +60,15 @@ class Extension:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
- self.version = f'{head.hexsha[:8]} ({ts})'
-
- except Exception:
+ self.commit_date = repo.head.commit.committed_date
+ ts = time.asctime(time.gmtime(self.commit_date))
+ if repo.active_branch:
+ self.branch = repo.active_branch.name
+ self.commit_hash = head.hexsha
+ self.version = f'{self.commit_hash[:8]} ({ts})'
+
+ except Exception as ex:
+ print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
def list_files(self, subdir, extension):
@@ -82,18 +91,30 @@ class Extension:
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
- self.status = "behind"
+ self.status = "new commits"
+ return
+
+ try:
+ origin = repo.rev_parse('origin')
+ if repo.head.commit != origin:
+ self.can_update = True
+ self.status = "behind HEAD"
return
+ except Exception:
+ self.can_update = False
+ self.status = "unknown (remote error)"
+ return
self.can_update = False
self.status = "latest"
- def fetch_and_reset_hard(self):
+ def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
- repo.git.reset('origin', hard=True)
+ repo.git.reset(commit, hard=True)
+ self.have_info_from_repo = False
def list_extensions():
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index d3a4d7ad..33d100dd 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
- if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
diff --git a/modules/extras.py b/modules/extras.py
index d8ece955..ff4e9c4e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,6 +1,7 @@
import os
import re
import shutil
+import json
import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -241,13 +242,54 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
+ metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+
+ if save_metadata:
+ merge_recipe = {
+ "type": "webui", # indicate this model was merged with webui's built-in merger
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "interp_method": interp_method,
+ "multiplier": multiplier,
+ "save_as_half": save_as_half,
+ "custom_name": custom_name,
+ "config_source": config_source,
+ "bake_in_vae": bake_in_vae,
+ "discard_weights": discard_weights,
+ "is_inpainting": result_is_inpainting_model,
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+ }
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ checkpoint_info.calculate_shorthash()
+ metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
+ }
+
+ metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+
+ metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 6df76858..99f1a0d3 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
+ # Missing RNG means the default was set, which is GPU RNG
+ if "RNG" not in res:
+ res["RNG"] = "GPU"
+
return res
@@ -304,6 +308,8 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('RNG', 'randn_source'),
+ ('NGMS', 's_min_uncond'),
]
diff --git a/modules/images.py b/modules/images.py
index b3535070..fd173829 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -318,6 +318,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
+NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
def sanitize_filename_part(text, replace_spaces=True):
@@ -352,6 +353,10 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
+ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
}
default_time_format = '%Y%m%d%H%M%S'
@@ -360,6 +365,22 @@ class FilenameGenerator:
self.seed = seed
self.prompt = prompt
self.image = image
+
+ def hasprompt(self, *args):
+ lower = self.prompt.lower()
+ if self.p is None or self.prompt is None:
+ return None
+ outres = ""
+ for arg in args:
+ if arg != "":
+ division = arg.split("|")
+ expected = division[0].lower()
+ default = division[1] if len(division) > 1 else ""
+ if lower.find(expected) >= 0:
+ outres = f'{outres}{expected}'
+ else:
+ outres = outres if default == "" else f'{outres}{default}'
+ return sanitize_filename_part(outres)
def prompt_no_style(self):
if self.p is None or self.prompt is None:
@@ -403,9 +424,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
- res += text
if pattern is None:
+ res += text
continue
pattern_args = []
@@ -426,11 +447,13 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is not None:
- res += str(replacement)
+ if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
+ continue
+ elif replacement is not None:
+ res += text + str(replacement)
continue
- res += f'[{pattern}]'
+ res += f'{text}[{pattern}]'
return res
diff --git a/modules/img2img.py b/modules/img2img.py
index d54728b7..56c846d6 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -4,7 +4,7 @@ import sys
import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
@@ -46,7 +46,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
if state.interrupted:
break
- img = Image.open(image)
+ try:
+ img = Image.open(image)
+ except UnidentifiedImageError:
+ continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
@@ -157,7 +160,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
override_settings=override_settings,
)
- p.scripts = modules.scripts.scripts_txt2img
+ p.scripts = modules.scripts.scripts_img2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index cbb80683..e1665708 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
category_types = ["artists", "flavors", "mediums", "movements"]
try:
- os.makedirs(tmpdir)
+ os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
@@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
- os.remove(tmpdir)
+ os.removedirs(tmpdir)
class InterrogateModels:
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 3df2c06b..1ad7989b 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -13,6 +13,18 @@ def connect(token, port, region):
config = conf.PyngrokConfig(
auth_token=token, region=region
)
+
+ # Guard for existing tunnels
+ existing = ngrok.get_tunnels(pyngrok_config=config)
+ if existing:
+ for established in existing:
+ # Extra configuration in the case that the user is also using ngrok for other tunnels
+ if established.config['addr'][-4:] == str(port):
+ public_url = existing[0].public_url
+ print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
+ 'You can use this link after the launch is complete.')
+ return
+
try:
if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 926ec3bb..6765bafe 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
+config_states_dir = os.path.join(script_path, "config_states")
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 09d8e605..736315e2 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -18,9 +18,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode == 1:
for img in image_folder:
- image = Image.open(img)
+ if isinstance(img, Image.Image):
+ image = img
+ fn = ''
+ else:
+ image = Image.open(os.path.abspath(img.name))
+ fn = os.path.splitext(img.orig_name)[0]
image_data.append(image)
- image_names.append(os.path.splitext(img.orig_name)[0])
+ image_names.append(fn)
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'
diff --git a/modules/processing.py b/modules/processing.py
index 6d9c6a8d..a48fff99 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -3,6 +3,7 @@ import math
import os
import sys
import warnings
+import hashlib
import torch
import numpy as np
@@ -105,7 +106,7 @@ class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -140,6 +141,7 @@ class StableDiffusionProcessing:
self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
+ self.s_min_uncond = s_min_uncond or opts.s_min_uncond
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
@@ -162,6 +164,8 @@ class StableDiffusionProcessing:
self.all_seeds = None
self.all_subseeds = None
self.iteration = 0
+ self.is_hr_pass = False
+
@property
def sd_model(self):
@@ -476,6 +480,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "Init image hash": getattr(p, 'init_img_hash', None),
+ "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
}
generation_params.update(p.extra_generation_params)
@@ -639,8 +646,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
+ step_multiplier = 1
+ if not shared.opts.dont_fix_second_order_samplers_schedule:
+ try:
+ step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
+ except:
+ pass
+ uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
+ c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -670,6 +683,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
+ p.batch_index = i
+
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
@@ -706,9 +721,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+ if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
+ image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
@@ -718,7 +733,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.return_mask:
output_images.append(image_mask)
-
+
if opts.return_mask_composite:
output_images.append(image_mask_composite)
@@ -871,6 +886,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr:
return samples
+ self.is_hr_pass = True
+
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
@@ -940,6 +957,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ self.is_hr_pass = False
+
return samples
@@ -1007,6 +1026,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
+
+ # Save init image
+ if opts.save_init_img:
+ self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+
image = images.flatten(img, opts.img2img_background_color)
if crop_region is None and self.resize_mode != 3:
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index aad4a629..d6079433 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -9,7 +9,7 @@ from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
-
+from modules import modelloader
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -23,7 +23,15 @@ class UpscalerRealESRGAN(Upscaler):
self.enable = True
self.scalers = []
scalers = self.load_models(path)
+
+ local_model_paths = self.find_models(ext_filter=[".pth"])
for scaler in scalers:
+ if scaler.local_data_path.startswith("http"):
+ filename = modelloader.friendly_name(scaler.local_data_path)
+ local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None)
+ if local:
+ scaler.local_data_path = local
+
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
@@ -64,7 +72,9 @@ class UpscalerRealESRGAN(Upscaler):
print(f"Unable to find model info: {path}")
return None
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+ if info.local_data_path.startswith("http"):
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/safe.py b/modules/safe.py
index 82d44be3..dadf319c 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -1,6 +1,5 @@
# this code is adapted from the script contributed by anon from /h/
-import io
import pickle
import collections
import sys
@@ -12,11 +11,9 @@ import _codecs
import zipfile
import re
-
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
-
def encode(*args):
out = _codecs.encode(*args)
return out
@@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
- return TypedStorage()
+ return TypedStorage(_internal=True)
def find_class(self, module, name):
if self.extra_handler is not None:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ea874df..4f7613a1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -52,6 +52,15 @@ class CheckpointInfo:
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.metadata = {}
+
+ _, ext = os.path.splitext(self.filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading checkpoint metadata: {filename}")
+
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
@@ -544,4 +553,4 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
- return sd_model \ No newline at end of file
+ return sd_model
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index a1aac7cf..bc074238 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -60,3 +60,13 @@ def store_latent(decoded):
class InterruptedException(BaseException):
pass
+
+
+if opts.randn_source == "CPU":
+ import torchsde._brownian.brownian_interval
+
+ def torchsde_randn(size, dtype, device, seed):
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+
+ torchsde._brownian.brownian_interval._randn = torchsde_randn
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..eb98e599 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
return denoised
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -115,12 +115,21 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
+ skip_uncond = False
- if tensor.shape[1] == uncond.shape[1]:
- if not is_edit_model:
- cond_in = torch.cat([tensor, uncond])
- else:
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
@@ -144,7 +153,13 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params)
@@ -152,20 +167,21 @@ class CFGDenoiser(torch.nn.Module):
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
+ sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- if not is_edit_model:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- else:
+ if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.step += 1
-
return denoised
@@ -190,7 +206,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
- if x.device.type == 'mps':
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
@@ -210,6 +226,7 @@ class KDiffusionSampler:
self.eta = None
self.config = None
self.last_latent = None
+ self.s_min_uncond = None
self.conditioning_key = sd_model.model.conditioning_key
@@ -244,6 +261,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -326,6 +344,7 @@ class KDiffusionSampler:
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -359,7 +378,8 @@ class KDiffusionSampler:
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
diff --git a/modules/shared.py b/modules/shared.py
index 5fd0eecb..6a2b3c2b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -4,6 +4,7 @@ import json
import os
import sys
import time
+import requests
from PIL import Image
import gradio as gr
@@ -39,6 +40,7 @@ restricted_opts = {
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
+ "outdir_init_images"
}
ui_reorder_categories = [
@@ -54,6 +56,21 @@ ui_reorder_categories = [
"scripts",
]
+# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
+gradio_hf_hub_themes = [
+ "gradio/glass",
+ "gradio/monochrome",
+ "gradio/seafoam",
+ "gradio/soft",
+ "freddyaboulton/dracula_revamped",
+ "gradio/dracula_test",
+ "abidlabs/dracula_test",
+ "abidlabs/pakistan",
+ "dawood/microsoft_windows",
+ "ysharma/steampunk"
+]
+
+
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@@ -252,7 +269,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
- "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
+ "save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
@@ -268,6 +285,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
+ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
@@ -283,6 +301,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
+ "SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -331,6 +351,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+ "randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -338,6 +359,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
+ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@@ -361,7 +383,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -382,11 +404,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
+ "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
}))
options_templates.update(options_section(('ui', "Live previews"), {
@@ -405,6 +429,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
@@ -424,6 +449,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
+ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
@@ -600,6 +626,24 @@ clip_model = None
progress_print_out = sys.stdout
+gradio_theme = gr.themes.Base()
+
+
+def reload_gradio_theme(theme_name=None):
+ global gradio_theme
+ if not theme_name:
+ theme_name = opts.gradio_theme
+
+ if theme_name == "Default":
+ gradio_theme = gr.themes.Default()
+ else:
+ try:
+ gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
+ except requests.exceptions.ConnectionError:
+ print("Can't access HuggingFace Hub, falling back to default Gradio theme")
+ gradio_theme = gr.themes.Default()
+
+
class TotalTQDM:
def __init__(self):
diff --git a/modules/styles.py b/modules/styles.py
index 990d5623..9ed85991 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -72,16 +72,14 @@ class StyleDatabase:
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
def save_styles(self, path: str) -> None:
- # Write to temporary file first, so we don't nuke the file if something goes wrong
- fd, temp_path = tempfile.mkstemp(".csv")
+ # Always keep a backup file around
+ if os.path.exists(path):
+ shutil.copy(path, path + ".bak")
+
+ fd = os.open(path, os.O_RDWR|os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
-
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.move(path, path + ".bak")
- shutil.move(temp_path, path)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 2239cb84..de1ddb59 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -161,7 +161,9 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
params.subindex = 0
filename = os.path.join(src, imagefile)
try:
- img = Image.open(filename).convert("RGB")
+ img = Image.open(filename)
+ img = ImageOps.exif_transpose(img)
+ img = img.convert("RGB")
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d2e62e58..379df243 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -233,6 +233,12 @@ class EmbeddingDatabase:
self.load_from_dir(embdir)
embdir.update()
+ # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
+ # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
+ sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
+ self.word_embeddings.clear()
+ self.word_embeddings.update(sorted_word_embeddings)
+
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
diff --git a/modules/ui.py b/modules/ui.py
index 43e36a0a..e00fbdeb 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -94,6 +94,9 @@ def send_gradio_gallery_to_image(x):
def visit(x, func, path=""):
if hasattr(x, 'children'):
+ if isinstance(x, gr.Tabs) and x.elem_id is not None:
+ # Tabs element can't have a label, have to use elem_id instead
+ func(f"{path}/Tabs@{x.elem_id}", x)
for c in x.children:
visit(c, func, path)
elif x.label is not None:
@@ -181,8 +184,8 @@ def create_seed_inputs(target_interface):
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False)
- random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
+ random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
+ reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
@@ -478,7 +481,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
@@ -1062,8 +1065,9 @@ def create_ui():
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+ save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
with FormRow():
with gr.Column():
@@ -1091,7 +1095,7 @@ def create_ui():
with gr.Row(variant="compact").style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
- with gr.Tab(label="Create embedding"):
+ with gr.Tab(label="Create embedding", id="create_embedding"):
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
@@ -1104,7 +1108,7 @@ def create_ui():
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
- with gr.Tab(label="Create hypernetwork"):
+ with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
@@ -1122,7 +1126,7 @@ def create_ui():
with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
- with gr.Tab(label="Preprocess images"):
+ with gr.Tab(label="Preprocess images", id="preprocess_images"):
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
@@ -1189,7 +1193,7 @@ def create_ui():
def get_textual_inversion_template_names():
return sorted([x for x in textual_inversion.textual_inversion_templates])
- with gr.Tab(label="Train"):
+ with gr.Tab(label="Train", id="train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
@@ -1247,7 +1251,7 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
@@ -1522,7 +1526,7 @@ def create_ui():
current_row.__exit__()
current_tab.__exit__()
- with gr.TabItem("Actions"):
+ with gr.TabItem("Actions", id="actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
@@ -1530,7 +1534,7 @@ def create_ui():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
- with gr.TabItem("Licenses"):
+ with gr.TabItem("Licenses", id="licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@@ -1608,7 +1612,7 @@ def create_ui():
for _interface, label, _ifid in interfaces:
shared.tab_names.append(label)
- with gr.Blocks(analytics_enabled=False, title="Stable Diffusion") as demo:
+ with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
@@ -1671,6 +1675,7 @@ def create_ui():
fn=get_settings_values,
inputs=[],
outputs=[component_dict[k] for k in component_keys],
+ queue=False,
)
def modelmerger(*args):
@@ -1700,6 +1705,7 @@ def create_ui():
config_source,
bake_in_vae,
discard_weights,
+ save_metadata,
],
outputs=[
primary_model_name,
@@ -1747,7 +1753,7 @@ def create_ui():
if init_field is not None:
init_field(saved_value)
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@@ -1777,12 +1783,27 @@ def create_ui():
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+ def check_tab_id(tab_id):
+ tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
+ if type(tab_id) == str:
+ tab_ids = [t.id for t in tab_items]
+ return tab_id in tab_ids
+ elif type(tab_id) == int:
+ return tab_id >= 0 and tab_id < len(tab_items)
+ else:
+ return False
+
+ if type(x) == gr.Tabs:
+ apply_field(x, 'selected', check_tab_id)
+
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
visit(train_interface, loadsave, "train")
+ loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
+
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 3b11dcc8..27ab3ebb 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -125,7 +125,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
generation_info = None
with gr.Column():
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 2b1da2cb..64451df7 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -62,3 +62,13 @@ class DropdownMulti(FormComponent, gr.Dropdown):
def get_block_name(self):
return "dropdown"
+
+
+class DropdownEditable(FormComponent, gr.Dropdown):
+ """Same as gr.Dropdown but allows editing value"""
+ def __init__(self, **kwargs):
+ super().__init__(allow_custom_value=True, **kwargs)
+
+ def get_block_name(self):
+ return "dropdown"
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index efd6cda2..99ac8756 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -2,6 +2,7 @@ import json
import os.path
import sys
import time
+from datetime import datetime
import traceback
import git
@@ -11,10 +12,12 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths
+from modules import extensions, shared, paths, config_states
+from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
available_extensions = {"extensions": []}
+STYLE_PRIMARY = ' style="color: var(--primary-400)"'
def check_access():
@@ -30,6 +33,9 @@ def apply_and_restart(disable_list, update_list, disable_all):
update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+ if update:
+ save_config_state("Backup (pre-update)")
+
update = set(update)
for ext in extensions.extensions:
@@ -50,6 +56,46 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.state.need_restart = True
+def save_config_state(name):
+ current_config_state = config_states.get_config()
+ if not name:
+ name = "Config"
+ current_config_state["name"] = name
+ filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json")
+ print(f"Saving backup of webui/extension state to {filename}.")
+ with open(filename, "w", encoding="utf-8") as f:
+ json.dump(current_config_state, f)
+ config_states.list_config_states()
+ new_value = next(iter(config_states.all_config_states.keys()), "Current")
+ new_choices = ["Current"] + list(config_states.all_config_states.keys())
+ return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
+
+
+def restore_config_state(confirmed, config_state_name, restore_type):
+ if config_state_name == "Current":
+ return "<span>Select a config to restore from.</span>"
+ if not confirmed:
+ return "<span>Cancelled.</span>"
+
+ check_access()
+
+ config_state = config_states.all_config_states[config_state_name]
+
+ print(f"*** Restoring webui state from backup: {restore_type} ***")
+
+ if restore_type == "extensions" or restore_type == "both":
+ shared.opts.restore_config_state_file = config_state["filepath"]
+ shared.opts.save(shared.config_filename)
+
+ if restore_type == "webui" or restore_type == "both":
+ config_states.restore_webui_config(config_state)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+ return ""
+
+
def check_updates(id_task, disable_list):
check_access()
@@ -76,6 +122,16 @@ def check_updates(id_task, disable_list):
return extension_table(), ""
+def make_commit_link(commit_hash, remote, text=None):
+ if text is None:
+ text = commit_hash[:8]
+ if remote.startswith("https://github.com/"):
+ href = os.path.join(remote, "commit", commit_hash)
+ return f'<a href="{href}" target="_blank">{text}</a>'
+ else:
+ return text
+
+
def extension_table():
code = f"""<!-- {time.time()} -->
<table id="extensions">
@@ -102,13 +158,17 @@ def extension_table():
style = ""
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
- style = ' style="color: var(--primary-400)"'
+ style = STYLE_PRIMARY
+
+ version_link = ext.version
+ if ext.commit_hash and ext.remote:
+ version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
code += f"""
<tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
- <td>{ext.version}</td>
+ <td>{version_link}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -121,6 +181,133 @@ def extension_table():
return code
+def update_config_states_table(state_name):
+ if state_name == "Current":
+ config_state = config_states.get_config()
+ else:
+ config_state = config_states.all_config_states[state_name]
+
+ config_name = config_state.get("name", "Config")
+ created_date = time.asctime(time.gmtime(config_state["created_at"]))
+ filepath = config_state.get("filepath", "<unknown>")
+
+ code = f"""<!-- {time.time()} -->"""
+
+ webui_remote = config_state["webui"]["remote"] or ""
+ webui_branch = config_state["webui"]["branch"]
+ webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
+ webui_commit_date = config_state["webui"]["commit_date"]
+ if webui_commit_date:
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
+ else:
+ webui_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
+
+ current_webui = config_states.get_webui_config()
+
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if current_webui["remote"] != webui_remote:
+ style_remote = STYLE_PRIMARY
+ if current_webui["branch"] != webui_branch:
+ style_branch = STYLE_PRIMARY
+ if current_webui["commit_hash"] != webui_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f"""<h2>Config Backup: {config_name}</h2>
+ <div><b>Filepath:</b> {filepath}</div>
+ <div><b>Created at:</b> {created_date}</div>"""
+
+ code += f"""<h2>WebUI State</h2>
+ <table id="config_state_webui">
+ <thead>
+ <tr>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{webui_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+ </tbody>
+ </table>
+ """
+
+ code += """<h2>Extension State</h2>
+ <table id="config_state_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+ ext_map = {ext.name: ext for ext in extensions.extensions}
+
+ for ext_name, ext_conf in config_state["extensions"].items():
+ ext_remote = ext_conf["remote"] or ""
+ ext_branch = ext_conf["branch"] or "<unknown>"
+ ext_enabled = ext_conf["enabled"]
+ ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
+ ext_commit_date = ext_conf["commit_date"]
+ if ext_commit_date:
+ ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ else:
+ ext_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
+ commit_link = make_commit_link(ext_commit_hash, ext_remote)
+ date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+
+ style_enabled = ""
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if ext_name in ext_map:
+ current_ext = ext_map[ext_name]
+ current_ext.read_info_from_repo()
+ if current_ext.enabled != ext_enabled:
+ style_enabled = STYLE_PRIMARY
+ if current_ext.remote != ext_remote:
+ style_remote = STYLE_PRIMARY
+ if current_ext.branch != ext_branch:
+ style_branch = STYLE_PRIMARY
+ if current_ext.commit_hash != ext_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f"""
+ <tr>
+ <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{ext_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+ """
+
+ code += """
+ </tbody>
+ </table>
+ """
+
+ return code
+
+
def normalize_git_url(url):
if url is None:
return ""
@@ -129,7 +316,7 @@ def normalize_git_url(url):
return url
-def install_extension_from_url(dirname, url):
+def install_extension_from_url(dirname, url, branch_name=None):
check_access()
assert url, 'No URL specified'
@@ -150,10 +337,17 @@ def install_extension_from_url(dirname, url):
try:
shutil.rmtree(tmpdir, True)
- with git.Repo.clone_from(url, tmpdir) as repo:
- repo.remote().fetch()
- for submodule in repo.submodules:
- submodule.update()
+ if not branch_name:
+ # if no branch is specified, use the default branch
+ with git.Repo.clone_from(url, tmpdir) as repo:
+ repo.remote().fetch()
+ for submodule in repo.submodules:
+ submodule.update()
+ else:
+ with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
+ repo.remote().fetch()
+ for submodule in repo.submodules:
+ submodule.update()
try:
os.rename(tmpdir, target_dir)
except OSError as err:
@@ -292,9 +486,11 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def create_ui():
import modules.ui
+ config_states.list_config_states()
+
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
- with gr.TabItem("Installed"):
+ with gr.TabItem("Installed", id="installed"):
with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary")
@@ -327,7 +523,7 @@ def create_ui():
outputs=[extensions_table, info],
)
- with gr.TabItem("Available"):
+ with gr.TabItem("Available", id="available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
@@ -374,16 +570,41 @@ def create_ui():
outputs=[available_extensions_table, install_result]
)
- with gr.TabItem("Install from URL"):
+ with gr.TabItem("Install from URL", id="install_from_url"):
install_url = gr.Text(label="URL for extension's git repository")
+ install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary")
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
- inputs=[install_dirname, install_url],
+ inputs=[install_dirname, install_url, install_branch],
outputs=[extensions_table, install_result],
)
+ with gr.TabItem("Backup/Restore"):
+ with gr.Row(elem_id="extensions_backup_top_row"):
+ config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
+ modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
+ config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
+ config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
+ with gr.Row(elem_id="extensions_backup_top_row2"):
+ config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
+ config_save_button = gr.Button(value="Save Current Config")
+
+ config_states_info = gr.HTML("")
+ config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+
+ config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
+
+ dummy_component = gr.Label(visible=False)
+ config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
+
+ config_states_list.change(
+ fn=update_config_states_table,
+ inputs=[config_states_list],
+ outputs=[config_states_table],
+ )
+
return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 25eb464b..aa2f5d1b 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -241,7 +241,7 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages:
- with gr.Tab(page.title):
+ with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")):
page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem)
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index b418d955..f25639e5 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -9,13 +9,13 @@ def create_ui():
with gr.Row().style(equal_height=False, variant='compact'):
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+ with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
- with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+ with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
- with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")