-rw-r--r--  README.md                                          3
-rw-r--r--  javascript/hints.js                               11
-rw-r--r--  javascript/ui.js                                  14
-rw-r--r--  modules/api/api.py                                 9
-rw-r--r--  modules/call_queue.py                              4
-rw-r--r--  modules/cmd_args.py                                1
-rw-r--r--  modules/devices.py                                83
-rw-r--r--  modules/errors.py                                  3
-rw-r--r--  modules/extras.py                                 39
-rw-r--r--  modules/img2img.py                                32
-rw-r--r--  modules/processing.py                              7
-rw-r--r--  modules/prompt_parser.py                          16
-rw-r--r--  modules/rng_philox.py                            102
-rw-r--r--  modules/scripts.py                                14
-rw-r--r--  modules/sd_disable_initialization.py             106
-rw-r--r--  modules/sd_hijack.py                               8
-rw-r--r--  modules/sd_hijack_clip.py                          4
-rw-r--r--  modules/sd_hijack_optimizations.py                 4
-rw-r--r--  modules/sd_models.py                              46
-rw-r--r--  modules/sd_models_xl.py                            9
-rw-r--r--  modules/sd_samplers_common.py                     12
-rw-r--r--  modules/sd_samplers_extra.py                      74
-rw-r--r--  modules/sd_samplers_kdiffusion.py                 13
-rw-r--r--  modules/shared.py                                  2
-rw-r--r--  modules/styles.py                                  5
-rw-r--r--  modules/sysinfo.py                                 6
-rw-r--r--  modules/textual_inversion/textual_inversion.py    19
-rw-r--r--  modules/ui.py                                    409
-rw-r--r--  modules/ui_checkpoint_merger.py                  124
-rw-r--r--  modules/ui_common.py                              32
-rw-r--r--  modules/ui_extra_networks_checkpoints.py           3
-rw-r--r--  modules/ui_extra_networks_hypernets.py             2
-rw-r--r--  modules/ui_extra_networks_textual_inversion.py     2
-rw-r--r--  modules/ui_extra_networks_user_metadata.py         1
-rw-r--r--  modules/ui_prompt_styles.py                      110
-rw-r--r--  scripts/xyz_grid.py                               14
-rw-r--r--  style.css                                         13
-rw-r--r--  webui.py                                           8
38 files changed, 959 insertions(+), 405 deletions(-)
diff --git a/README.md b/README.md
index b796d150..2fd6e425 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
-- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
@@ -169,5 +169,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
+- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/javascript/hints.js b/javascript/hints.js
index 4167cb28..6de9372e 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
}
});
+
+onUiLoaded(function() {
+ for (var comp of window.gradio_config.components) {
+ if (comp.props.webui_tooltip && comp.props.elem_id) {
+ var elem = gradioApp().getElementById(comp.props.elem_id);
+ if (elem) {
+ elem.title = comp.props.webui_tooltip;
+ }
+ }
+ }
+});
diff --git a/javascript/ui.js b/javascript/ui.js
index d70a681b..abf23a78 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -152,7 +152,11 @@ function submit() {
showSubmitButtons('txt2img', false);
var id = randomId();
- localStorage.setItem("txt2img_task_id", id);
+ try {
+ localStorage.setItem("txt2img_task_id", id);
+ } catch (e) {
+ console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
+ }
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
showSubmitButtons('img2img', false);
var id = randomId();
- localStorage.setItem("img2img_task_id", id);
+ try {
+ localStorage.setItem("img2img_task_id", id);
+ } catch (e) {
+ console.warn(`Failed to save img2img task id to localStorage: ${e}`);
+ }
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
showSubmitButtons('img2img', true);
@@ -191,8 +199,6 @@ function restoreProgressTxt2img() {
showRestoreProgressButton("txt2img", false);
var id = localStorage.getItem("txt2img_task_id");
- id = localStorage.getItem("txt2img_task_id");
-
if (id) {
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
showSubmitButtons('txt2img', true);
diff --git a/modules/api/api.py b/modules/api/api.py
index 606db179..908c4514 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -197,6 +197,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+ self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
@@ -343,6 +344,7 @@ class Api:
processed = process_images(p)
finally:
shared.state.end()
+ shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -402,6 +404,7 @@ class Api:
processed = process_images(p)
finally:
shared.state.end()
+ shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -608,6 +611,10 @@ class Api:
with self.queue_lock:
shared.refresh_checkpoints()
+ def refresh_vae(self):
+ with self.queue_lock:
+ shared_items.refresh_vae_list()
+
def create_embedding(self, args: dict):
try:
shared.state.begin(job="create_embedding")
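
The new /sdapi/v1/refresh-vae route mirrors refresh-checkpoints: it takes no body and re-scans the VAE list under the queue lock. A minimal client sketch, assuming the server runs at the default local address:

```
import requests

# ask the running server to re-scan available VAE files, no restart needed
response = requests.post("http://127.0.0.1:7860/sdapi/v1/refresh-vae")
response.raise_for_status()  # 200 with a null body on success
```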
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 61aa240f..f2eb17d6 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -3,7 +3,7 @@ import html
import threading
import time
-from modules import shared, progress, errors
+from modules import shared, progress, errors, devices
queue_lock = threading.Lock()
@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
error_message = f'{type(e).__name__}: {e}'
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
+ devices.torch_gc()
+
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 1262f1a4..64f21e01 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..00a00b18 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors
+from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True
-
errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -90,23 +93,87 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
+nv_rng = None
+
+
def randn(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
from modules.shared import opts
- torch.manual_seed(seed)
+ manual_seed(seed)
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=device)
+
+ local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
def randn_without_seed(shape):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
def autocast(disable=False):
from modules import shared
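
Taken together, the new functions separate seeding from sampling. A rough sketch of the intended call pattern, with illustrative shapes (the opts.randn_source wiring is as above):

```
from modules import devices

devices.manual_seed(1234)                        # NV: builds a Philox generator;
                                                 # otherwise torch.manual_seed
noise = devices.randn(1234, (4, 64, 64))         # seed + first tensor in one call
extra = devices.randn_without_seed((4, 64, 64))  # continues the seeded stream
same_shape = devices.randn_like(noise)           # matches device and dtype of noise

one_off = devices.randn_local(42, (4, 64, 64))   # does not touch global RNG state
```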
diff --git a/modules/errors.py b/modules/errors.py
index 5271a9fe..dffabe45 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -14,7 +14,8 @@ def record_exception():
if exception_records and exception_records[-1] == e:
return
- exception_records.append((e, tb))
+ from modules import sysinfo
+ exception_records.append(sysinfo.format_exception(e, tb))
if len(exception_records) > 5:
exception_records.pop(0)
diff --git a/modules/extras.py b/modules/extras.py
index e9c0263e..2a310ae3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,7 +7,7 @@ import json
import torch
import tqdm
-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
+ metadata = {}
+
+ for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
+ checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
+ if checkpoint_info is None:
+ continue
+
+ metadata.update(checkpoint_info.metadata)
+
+ return json.dumps(metadata, indent=4, ensure_ascii=False)
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
shared.state.begin(job="model-merge")
def fail(message):
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = None
+ metadata = {}
+
+ if save_metadata and copy_metadata_fields:
+ if primary_model_info:
+ metadata.update(primary_model_info.metadata)
+ if secondary_model_info:
+ metadata.update(secondary_model_info.metadata)
+ if tertiary_model_info:
+ metadata.update(tertiary_model_info.metadata)
if save_metadata:
- metadata = {"format": "pt"}
+ try:
+ metadata.update(json.loads(metadata_json))
+ except Exception as e:
+ errors.display(e, "reading metadata from json")
+
+ metadata["format"] = "pt"
+ if save_metadata and add_merge_recipe:
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
"is_inpainting": result_is_inpainting_model,
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
}
- metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
sd_merge_models = {}
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info:
add_model_metadata(tertiary_model_info)
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
else:
torch.save(theta_0, output_modelname)
diff --git a/modules/img2img.py b/modules/img2img.py
index a811e7a4..68e415ef 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -10,7 +10,6 @@ from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
-from modules.images import save_image
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
@@ -18,9 +17,10 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
+ output_dir = output_dir.strip()
processing.fix_seed(p)
- images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
is_inpaint_batch = False
if inpaint_mask_dir:
@@ -32,11 +32,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
- save_normally = output_dir == ''
-
- p.do_not_save_grid = True
- p.do_not_save_samples = not save_normally
-
state.job_count = len(images) * p.n_iter
# extract "default" params to use in case getting png info fails
@@ -111,21 +106,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
- proc = process_images(p)
-
- for n, processed_image in enumerate(proc.images):
- filename = image_path.stem
- infotext = proc.infotext(p, n)
- relpath = os.path.dirname(os.path.relpath(image, input_dir))
-
- if n > 0:
- filename += f"-{n}"
-
- if not save_normally:
- os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
- if processed_image.mode == 'RGBA':
- processed_image = processed_image.convert("RGB")
- save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
+ if output_dir:
+ p.outpath_samples = output_dir
+ p.override_settings['save_to_dirs'] = False
+ if p.n_iter > 1 or p.batch_size > 1:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+ else:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+ process_images(p)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
diff --git a/modules/processing.py b/modules/processing.py
index b0992ee1..8f34c8b4 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -492,7 +492,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
subnoise = None
- if subseeds is not None:
+ if subseeds is not None and subseed_strength != 0:
subseed = 0 if i >= len(subseeds) else subseeds[i]
subnoise = devices.randn(subseed, noise_shape)
@@ -524,7 +524,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
cnt = p.sampler.number_of_needed_noises(p)
if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
+ devices.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -636,7 +636,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
- "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
@@ -1348,6 +1348,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = image.to(shared.device, dtype=devices.dtype_vae)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ devices.torch_gc()
if self.resize_mode == 3:
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 203ae1ac..8169a459 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -19,7 +19,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
@@ -60,11 +60,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
- tree.children[-1] = float(tree.children[-1])
- if tree.children[-1] < 1:
- tree.children[-1] *= steps
- tree.children[-1] = min(steps, int(tree.children[-1]))
- res.append(tree.children[-1])
+ tree.children[-2] = float(tree.children[-2])
+ if tree.children[-2] < 1:
+ tree.children[-2] *= steps
+ tree.children[-2] = min(steps, int(tree.children[-2]))
+ res.append(tree.children[-2])
def alternate(self, tree):
res.extend(range(1, steps+1))
@@ -75,7 +75,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def at_step(step, tree):
class AtStep(lark.Transformer):
def scheduled(self, args):
- before, after, _, when = args
+ before, after, _, when, _ = args
yield before or () if step <= when else after
def alternate(self, args):
yield next(args[(step - 1)%len(args)])
@@ -333,7 +333,7 @@ re_attention = re.compile(r"""
\\|
\(|
\[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
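
The net effect of the grammar and regex changes is that whitespace around the schedule number, and around an attention weight such as "(word: 1.2 )", no longer breaks parsing. A small sketch using the module's own function; the output format follows its doctests:

```
from modules.prompt_parser import get_learned_conditioning_prompt_schedules

# "[cat:dog:0.5 ]" now parses like "[cat:dog:0.5]":
# switch prompts at step 10 of 20
schedules = get_learned_conditioning_prompt_schedules(["[cat:dog:0.5 ]"], 20)
print(schedules)  # expected: [[[10, 'cat'], [20, 'dog']]]
```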
diff --git a/modules/rng_philox.py b/modules/rng_philox.py
new file mode 100644
index 00000000..5532cf9d
--- /dev/null
+++ b/modules/rng_philox.py
@@ -0,0 +1,102 @@
+"""RNG imitiating torch cuda randn on CPU. You are welcome.
+
+Usage:
+
+```
+g = Generator(seed=0)
+print(g.randn(shape=(3, 4)))
+```
+
+Expected output:
+```
+[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
+ [-1.07454231 -0.36314407 -1.67105067 2.26550497]]
+```
+"""
+
+import numpy as np
+
+philox_m = [0xD2511F53, 0xCD9E8D57]
+philox_w = [0x9E3779B9, 0xBB67AE85]
+
+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
+
+
+def uint32(x):
+ """Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
+ return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
+
+
+def philox4_round(counter, key):
+ """A single round of the Philox 4x32 random number generator."""
+
+ v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
+ v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
+
+ counter[0] = v2[1] ^ counter[1] ^ key[0]
+ counter[1] = v2[0]
+ counter[2] = v1[1] ^ counter[3] ^ key[1]
+ counter[3] = v1[0]
+
+
+def philox4_32(counter, key, rounds=10):
+ """Generates 32-bit random numbers using the Philox 4x32 random number generator.
+
+ Parameters:
+ counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
+ key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
+ rounds (int): The number of rounds to perform.
+
+ Returns:
+ numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
+ """
+
+ for _ in range(rounds - 1):
+ philox4_round(counter, key)
+
+ key[0] = key[0] + philox_w[0]
+ key[1] = key[1] + philox_w[1]
+
+ philox4_round(counter, key)
+ return counter
+
+
+def box_muller(x, y):
+ """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
+ u = x * two_pow32_inv + two_pow32_inv / 2
+ v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+
+ s = np.sqrt(-2.0 * np.log(u))
+
+ r1 = s * np.sin(v)
+ return r1.astype(np.float32)
+
+
+class Generator:
+ """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
+
+ def __init__(self, seed):
+ self.seed = seed
+ self.offset = 0
+
+ def randn(self, shape):
+ """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
+
+ n = 1
+ for x in shape:
+ n *= x
+
+ counter = np.zeros((4, n), dtype=np.uint32)
+ counter[0] = self.offset
+ counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
+ self.offset += 1
+
+ key = np.empty(n, dtype=np.uint64)
+ key.fill(self.seed)
+ key = uint32(key)
+
+ g = philox4_32(counter, key)
+
+ return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
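
On a machine with an NVIDIA GPU the generator can be checked against torch directly; a quick sanity sketch (agreement is expected up to float32 rounding):

```
import torch
from modules.rng_philox import Generator

torch.manual_seed(0)
reference = torch.randn(3, 4, device='cuda').cpu().numpy()

cpu_copy = Generator(seed=0).randn((3, 4))
print(abs(reference - cpu_copy).max())  # expected to be ~0
```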
diff --git a/modules/scripts.py b/modules/scripts.py
index 5b4edcac..edf7347e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -646,6 +646,8 @@ def add_classes_to_gradio_component(comp):
def IOComponent_init(self, *args, **kwargs):
+ self.webui_tooltip = kwargs.pop('tooltip', None)
+
if scripts_current is not None:
scripts_current.before_component(self, **kwargs)
@@ -663,8 +665,20 @@ def IOComponent_init(self, *args, **kwargs):
return res
+def Block_get_config(self):
+ config = original_Block_get_config(self)
+
+ webui_tooltip = getattr(self, 'webui_tooltip', None)
+ if webui_tooltip:
+ config["webui_tooltip"] = webui_tooltip
+
+ return config
+
+
original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.components.Block.get_config
gr.components.IOComponent.__init__ = IOComponent_init
+gr.components.Block.get_config = Block_get_config
def BlockContext_init(self, *args, **kwargs):
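
With IOComponent_init popping the kwarg and Block_get_config forwarding it to the page, any component built after the hijack can declare a tooltip, which hints.js (above) copies into the element's title attribute on load. A sketch with an illustrative button:

```
import gradio as gr

# `tooltip` is not a stock gradio kwarg; the hijacked IOComponent_init
# pops it into webui_tooltip before gradio sees the remaining kwargs
btn = gr.Button("Generate", elem_id="txt2img_generate",
                tooltip="Start generating images")
```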
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 9fc89dc6..695c5736 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -3,8 +3,31 @@ import open_clip
import torch
import transformers.utils.hub
+from modules import shared
-class DisableInitialization:
+
+class ReplaceHelper:
+ def __init__(self):
+ self.replaced = []
+
+ def replace(self, obj, field, func):
+ original = getattr(obj, field, None)
+ if original is None:
+ return None
+
+ self.replaced.append((obj, field, original))
+ setattr(obj, field, func)
+
+ return original
+
+ def restore(self):
+ for obj, field, original in self.replaced:
+ setattr(obj, field, original)
+
+ self.replaced.clear()
+
+
+class DisableInitialization(ReplaceHelper):
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
@@ -21,7 +44,7 @@ class DisableInitialization:
"""
def __init__(self, disable_clip=True):
- self.replaced = []
+ super().__init__()
self.disable_clip = disable_clip
def replace(self, obj, field, func):
@@ -86,8 +109,81 @@ class DisableInitialization:
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
- for obj, field, original in self.replaced:
- setattr(obj, field, original)
+ self.restore()
- self.replaced.clear()
+class InitializeOnMeta(ReplaceHelper):
+ """
+ Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
+ which results in those parameters having no values and taking no memory. model.to() will be broken and
+ will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
+
+ Usage:
+ ```
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+ ```
+ """
+
+ def __enter__(self):
+ if shared.cmd_opts.disable_model_loading_ram_optimization:
+ return
+
+ def set_device(x):
+ x["device"] = "meta"
+ return x
+
+ linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
+ conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
+ mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
+ self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.restore()
+
+
+class LoadStateDictOnMeta(ReplaceHelper):
+ """
+ Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
+ As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
+ Meant to be used together with InitializeOnMeta above.
+
+ Usage:
+ ```
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
+ model.load_state_dict(state_dict, strict=False)
+ ```
+ """
+
+ def __init__(self, state_dict, device):
+ super().__init__()
+ self.state_dict = state_dict
+ self.device = device
+
+ def __enter__(self):
+ if shared.cmd_opts.disable_model_loading_ram_optimization:
+ return
+
+ sd = self.state_dict
+ device = self.device
+
+ def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+ params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+
+ for name, param in params:
+ if param.is_meta:
+ self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+
+ original(self, state_dict, prefix, *args, **kwargs)
+
+ for name, _ in params:
+ key = prefix + name
+ if key in sd:
+ del sd[key]
+
+ linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
+ conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
+ mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.restore()
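
The torch mechanism behind these two helpers can be shown in isolation: a parameter on the meta device has shape and dtype but no storage, and only costs memory once it is re-created while reading the state dict. A standalone sketch:

```
import torch

layer = torch.nn.Linear(1024, 1024, device="meta")  # no weight memory allocated
assert layer.weight.is_meta

# materialize the parameters on CPU, then fill them from a state dict
layer.weight = torch.nn.Parameter(torch.zeros_like(layer.weight, device="cpu"))
layer.bias = torch.nn.Parameter(torch.zeros_like(layer.bias, device="cpu"))
layer.load_state_dict({"weight": torch.randn(1024, 1024),
                       "bias": torch.randn(1024)})
```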
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c8fdd4f1..cfa5f0eb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2':
- embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
class EmbeddingsWithFixes(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
+ def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
+ self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
- emb = devices.cond_cast_unet(embedding.vec)
+ vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+ emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 16a5500e..8f29057a 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += 1
continue
- emb_len = int(embedding.vec.shape[0])
+ emb_len = int(embedding.vectors)
if len(chunk.tokens) + emb_len > self.chunk_length:
next_chunk()
@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
hashes.append(f"{name}: {shorthash}")
if hashes:
+ if self.hijack.extra_generation_params.get("TI hashes"):
+ hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if getattr(self.wrapped, 'return_pooled', False):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b5f85ba5..0e810eec 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ slice_size = q.shape[1] // steps
for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
+ end = min(i + slice_size, q.shape[1])
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index fb31a793..1d93d893 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -33,6 +33,8 @@ class CheckpointInfo:
self.filename = filename
abspath = os.path.abspath(filename)
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
@@ -43,6 +45,19 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
+ def read_metadata():
+ metadata = read_metadata_from_safetensors(filename)
+ self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
+
+ return metadata
+
+ self.metadata = {}
+ if self.is_safetensors:
+ try:
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
+ except Exception as e:
+ errors.display(e, f"reading metadata for {filename}")
+
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
@@ -53,16 +68,7 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
-
- self.metadata = {}
-
- _, ext = os.path.splitext(self.filename)
- if ext.lower() == ".safetensors":
- try:
- self.metadata = read_metadata_from_safetensors(filename)
- except Exception as e:
- errors.display(e, f"reading checkpoint metadata: {filename}")
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -79,7 +85,7 @@ class CheckpointInfo:
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
- checkpoints_list.pop(self.title)
+ checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
self.register()
@@ -460,7 +466,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
-
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -495,19 +500,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
- sd_model = instantiate_from_config(sd_config.model)
- except Exception:
- pass
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+
+ except Exception as e:
+ errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
- sd_model = instantiate_from_config(sd_config.model)
+
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 40559208..bc219508 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
return torch.cat(res, dim=1)
+def tokenize(self: sgm.modules.GeneralConditioner, texts):
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
+ return embedder.tokenize(texts)
+
+ raise AssertionError('no tokenizer available')
+
+
+
def process_texts(self, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
return embedder.process_texts(texts)
@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 763829f1..5deda761 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,10 +2,8 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
-
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
-import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -85,11 +83,13 @@ class InterruptedException(BaseException):
pass
-if opts.randn_source == "CPU":
+def replace_torchsde_brownian():
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
- generator = torch.Generator(devices.cpu).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ return devices.randn_local(seed, size).to(device=device, dtype=dtype)
torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_brownian()
diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py
new file mode 100644
index 00000000..1b981ca8
--- /dev/null
+++ b/modules/sd_samplers_extra.py
@@ -0,0 +1,74 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+
+
+@torch.no_grad()
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
+ """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
+ Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
+ If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
+ """
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ step_id = 0
+ from k_diffusion.sampling import to_d, get_sigmas_karras
+
+ def heun_step(x, old_sigma, new_sigma, second_order=True):
+ nonlocal step_id
+ denoised = model(x, old_sigma * s_in, **extra_args)
+ d = to_d(x, old_sigma, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
+ dt = new_sigma - old_sigma
+ if new_sigma == 0 or not second_order:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
+ d_2 = to_d(x_2, new_sigma, denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ step_id += 1
+ return x
+
+ steps = sigmas.shape[0] - 1
+ if restart_list is None:
+ if steps >= 20:
+ restart_steps = 9
+ restart_times = 1
+ if steps >= 36:
+ restart_steps = steps // 4
+ restart_times = 2
+ sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
+ restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
+ else:
+ restart_list = {}
+
+ restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
+
+ step_list = []
+ for i in range(len(sigmas) - 1):
+ step_list.append((sigmas[i], sigmas[i + 1]))
+ if i + 1 in restart_list:
+ restart_steps, restart_times, restart_max = restart_list[i + 1]
+ min_idx = i + 1
+ max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
+ if max_idx < min_idx:
+ sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
+ while restart_times > 0:
+ restart_times -= 1
+ step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+
+ last_sigma = None
+ for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
+ if last_sigma is None:
+ last_sigma = old_sigma
+ elif last_sigma < old_sigma:
+ x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
+ x = heun_step(x, old_sigma, new_sigma)
+ last_sigma = new_sigma
+
+ return x
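
The re-noising step above relies on independent gaussian variances adding up: a sample at noise level sigma_min plus fresh noise with standard deviation sqrt(sigma_max^2 - sigma_min^2) lands at level sigma_max. A quick numeric check:

```
import torch

sigma_min, sigma_max = 0.1, 2.0
x = torch.randn(1_000_000) * sigma_min
x = x + torch.randn_like(x) * (sigma_max ** 2 - sigma_min ** 2) ** 0.5
print(x.std())  # ~2.0, i.e. sigma_max
```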
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 5552a8dc..d72c1b5f 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -2,7 +2,7 @@ from collections import deque
import torch
import inspect
import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common
+from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
from modules.shared import opts, state
import modules.shared as shared
@@ -30,12 +30,14 @@ samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+ ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
+
samplers_data_k_diffusion = [
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
for label, funcname, aliases, options in samplers_k_diffusion
- if hasattr(k_diffusion.sampling, funcname)
+ if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
]
sampler_extra_params = {
@@ -258,10 +260,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
- if opts.randn_source == "CPU" or x.device.type == 'mps':
- return torch.randn_like(x, device=devices.cpu).to(x.device)
- else:
- return torch.randn_like(x)
+ return devices.randn_like(x)
class KDiffusionSampler:
@@ -270,7 +269,7 @@ class KDiffusionSampler:
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
- self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
diff --git a/modules/shared.py b/modules/shared.py
index aa72c9c8..7103b4ca 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -428,7 +428,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
- "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
diff --git a/modules/styles.py b/modules/styles.py
index ec0e1bc5..0740fe1b 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -106,10 +106,7 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")
- fd = os.open(path, os.O_RDWR | os.O_CREAT)
- with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
- # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+ with open(path, "w", encoding="utf-8-sig", newline='') as file:
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index 5f15ac4f..cf24c6dd 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -109,11 +109,15 @@ def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+def format_exception(e, tb):
+ return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
def get_exceptions():
try:
from modules import errors
- return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
+ return list(reversed(errors.exception_records))
except Exception as e:
return str(e)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6166c76f..4713bc2d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -181,29 +181,38 @@ class EmbeddingDatabase:
else:
return
+
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+ vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+ shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+ vectors = data['clip_g'].shape[0]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
- vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
+ embedding.vectors = vectors
+ embedding.shape = shape
embedding.filename = path
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
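
For the new SDXL branch, the embedding file is a dict with one tensor per text encoder. A sketch of the layout the loader recognizes; the 1280 and 768 widths are the usual SDXL encoder sizes, shown for illustration:

```
import torch

# an SDXL textual inversion embedding with 2 vectors per token
data = {
    "clip_g": torch.zeros(2, 1280),  # OpenCLIP ViT-bigG hidden size
    "clip_l": torch.zeros(2, 768),   # CLIP ViT-L hidden size
}
# the loader derives: vectors = 2, shape = 1280 + 768 = 2048
```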
diff --git a/modules/ui.py b/modules/ui.py
index 07ecee7b..03306ba9 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,30 +12,24 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript
-
from modules.shared import opts, cmd_opts
-import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
+import modules.hypernetworks.ui as hypernetworks_ui
+import modules.textual_inversion.ui as textual_inversion_ui
+import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
+import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras
create_setting_component = ui_settings.create_setting_component
@@ -92,19 +86,6 @@ def send_gradio_gallery_to_image(x):
return image_from_url_text(x[0])
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
@@ -129,13 +110,6 @@ def resize_from_to_html(width, height, scale_by):
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
-def apply_styles(prompt, prompt_neg, styles):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
if mode in {0, 1, 3, 4}:
return [interrogation_function(ii_singles[mode]), None]
@@ -267,71 +241,78 @@ def update_token_counter(text, steps):
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
+class Toprow:
+ """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
- with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
- with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ def __init__(self, is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+ self.id_part = id_part
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_classes="interrogate-col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
- with gr.Row(elem_id=f"{id_part}_tools"):
- paste = ToolButton(value=paste_symbol, elem_id="paste")
- clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
- prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
- save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
- restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
- token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
- negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
- with gr.Row(elem_id=f"{id_part}_styles_row"):
- prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
- create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+ self.button_interrogate = None
+ self.button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_classes="interrogate-col"):
+ self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+ self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+ self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ self.skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ self.interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Row(elem_id=f"{id_part}_tools"):
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+ self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+ self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+ self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+ self.clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[self.prompt, self.negative_prompt],
+ outputs=[self.prompt, self.negative_prompt],
+ )
+
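+ # replaces the old per-tab style apply/save ToolButtons with the styles edit dialog from ui_prompt_styles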
+ self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
+
+ self.prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[self.prompt_img],
+ outputs=[self.prompt, self.prompt_img],
+ show_progress=False,
+ )
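+
+ # Each tab now builds a single instance of this class (e.g. toprow = Toprow(is_img2img=False))
+ # and references its attributes such as toprow.prompt and toprow.submit, instead of unpacking
+ # the long tuple that create_toprow() used to return.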
def setup_progressbar(*args, **kwargs):
@@ -415,22 +396,21 @@ def create_ui():
parameters_copypaste.reset()
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+ scripts.scripts_current = scripts.scripts_txt2img
+ scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+ toprow = Toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
- extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
- modules.scripts.scripts_txt2img.prepare_ui()
+ scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
@@ -498,10 +478,10 @@ def create_ui():
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ custom_inputs = scripts.scripts_txt2img.setup_ui()
else:
- modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+ scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
@@ -532,9 +512,9 @@ def create_ui():
_js="submit",
inputs=[
dummy_component,
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_styles,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
steps,
sampler_index,
restore_faces,
@@ -569,12 +549,12 @@ def create_ui():
show_progress=False,
)
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
+ toprow.prompt.submit(**txt2img_args)
+ toprow.submit.click(**txt2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
- restore_progress_button.click(
+ toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressTxt2img",
inputs=[dummy_component],
@@ -587,18 +567,6 @@ def create_ui():
show_progress=False,
)
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ],
- show_progress=False,
- )
-
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
@@ -607,8 +575,8 @@ def create_ui():
)
txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
+ (toprow.prompt, "Prompt"),
+ (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
@@ -621,7 +589,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
- (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -635,16 +603,16 @@ def create_ui():
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
- *modules.scripts.scripts_txt2img.infotext_fields
+ *scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+ paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
))
txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
+ toprow.prompt,
+ toprow.negative_prompt,
steps,
sampler_index,
cfg_scale,
@@ -653,22 +621,20 @@ def create_ui():
height,
]
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+ scripts.scripts_current = scripts.scripts_img2img
+ scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
-
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+ toprow = Toprow(is_img2img=True)
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
- extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
@@ -764,7 +730,7 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
- modules.scripts.scripts_img2img.prepare_ui()
+ scripts.scripts_img2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
@@ -845,7 +811,7 @@ def create_ui():
elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"):
- custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+ custom_inputs = scripts.scripts_img2img.setup_ui()
elif category == "inpaint":
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
@@ -876,34 +842,22 @@ def create_ui():
outputs=[inpaint_controls, mask_alpha],
)
else:
- modules.scripts.scripts_img2img.setup_ui_for_section(category)
+ scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- img2img_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- img2img_prompt_img
- ],
- outputs=[
- img2img_prompt,
- img2img_prompt_img
- ],
- show_progress=False,
- )
-
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
dummy_component,
dummy_component,
- img2img_prompt,
- img2img_negative_prompt,
- img2img_prompt_styles,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
init_img,
sketch,
init_img_with_mask,
@@ -962,11 +916,11 @@ def create_ui():
inpaint_color_sketch,
init_img_inpaint,
],
- outputs=[img2img_prompt, dummy_component],
+ outputs=[toprow.prompt, dummy_component],
)
- img2img_prompt.submit(**img2img_args)
- submit.click(**img2img_args)
+ toprow.prompt.submit(**img2img_args)
+ toprow.submit.click(**img2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
@@ -978,7 +932,7 @@ def create_ui():
show_progress=False,
)
- restore_progress_button.click(
+ toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressImg2img",
inputs=[dummy_component],
@@ -991,46 +945,24 @@ def create_ui():
show_progress=False,
)
- img2img_interrogate.click(
+ toprow.button_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args,
)
- img2img_deepbooru.click(
+ toprow.button_deepbooru.click(
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
**interrogate_args,
)
- prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
- style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
- for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
- button.click(
- fn=add_style,
- _js="ask_for_style_name",
- # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
- # the same number of parameters, but we only know the style-name after the JavaScript prompt
- inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_styles, img2img_prompt_styles],
- )
-
- for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
- button.click(
- fn=apply_styles,
- _js=js_func,
- inputs=[prompt, negative_prompt, styles],
- outputs=[prompt, negative_prompt, styles],
- )
-
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+ toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [
- (img2img_prompt, "Prompt"),
- (img2img_negative_prompt, "Negative prompt"),
+ (toprow.prompt, "Prompt"),
+ (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
@@ -1044,18 +976,18 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
- (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
- *modules.scripts.scripts_img2img.infotext_fields
+ *scripts.scripts_img2img.infotext_fields
]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+ paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
))
- modules.scripts.scripts_current = None
+ scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
ui_postprocessing.create_ui()
@@ -1083,58 +1015,7 @@ def create_ui():
outputs=[html, generation_info, html2],
)
- def update_interp_description(value):
- interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
- interp_descriptions = {
- "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
- "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
- "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
- }
- return interp_descriptions[value]
-
- with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='compact'):
- interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
-
- with FormRow(elem_id="modelmerger_models"):
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
- create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
-
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
- create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
-
- tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
- create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
-
- custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
- interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
- interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
-
- with FormRow():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
- save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
-
- with FormRow():
- with gr.Column():
- config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
-
- with gr.Column():
- with FormRow():
- bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
- create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
-
- with FormRow():
- discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
-
- with gr.Row():
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
-
- with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
- with gr.Group(elem_id="modelmerger_results_panel"):
- modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+ modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
@@ -1160,7 +1041,7 @@ def create_ui():
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended: Swish / Linear (none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
@@ -1305,7 +1186,7 @@ def create_ui():
ti_outcome = gr.HTML(elem_id="ti_error", value="")
create_embedding.click(
- fn=modules.textual_inversion.ui.create_embedding,
+ fn=textual_inversion_ui.create_embedding,
inputs=[
new_embedding_name,
initialization_text,
@@ -1320,7 +1201,7 @@ def create_ui():
)
create_hypernetwork.click(
- fn=modules.hypernetworks.ui.create_hypernetwork,
+ fn=hypernetworks_ui.create_hypernetwork,
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
@@ -1340,7 +1221,7 @@ def create_ui():
)
run_preprocess.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
@@ -1376,7 +1257,7 @@ def create_ui():
)
train_embedding.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
@@ -1410,7 +1291,7 @@ def create_ui():
)
train_hypernetwork.click(
- fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
@@ -1464,7 +1345,7 @@ def create_ui():
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "train"),
]
@@ -1516,49 +1397,11 @@ def create_ui():
settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
- def modelmerger(*args):
- try:
- results = modules.extras.run_modelmerger(*args)
- except Exception as e:
- errors.report("Error loading/saving model file", exc_info=True)
- modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
- return results
-
- modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
- modelmerger_merge.click(
- fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
- _js='modelmerger',
- inputs=[
- dummy_component,
- primary_model_name,
- secondary_model_name,
- tertiary_model_name,
- interp_method,
- interp_amount,
- save_as_half,
- custom_name,
- checkpoint_format,
- config_source,
- bake_in_vae,
- discard_weights,
- save_metadata,
- ],
- outputs=[
- primary_model_name,
- secondary_model_name,
- tertiary_model_name,
- settings.component_dict['sd_model_checkpoint'],
- modelmerger_result,
- ]
- )
+ modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
loadsave.dump_defaults()
demo.ui_loadsave = loadsave
- # Required as a workaround for change() event not triggering when loading values from ui-config.json
- interp_description.value = update_interp_description(interp_method.value)
-
return demo
diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py
new file mode 100644
index 00000000..4863d861
--- /dev/null
+++ b/modules/ui_checkpoint_merger.py
@@ -0,0 +1,124 @@
+
+import gradio as gr
+
+from modules import sd_models, sd_vae, errors, extras, call_queue
+from modules.ui_components import FormRow
+from modules.ui_common import create_refresh_button
+
+
+def update_interp_description(value):
+ interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+ interp_descriptions = {
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+ }
+ return interp_descriptions[value]
+
+
+def modelmerger(*args):
+ try:
+ results = extras.run_modelmerger(*args)
+ except Exception as e:
+ errors.report("Error loading/saving model file", exc_info=True)
+ sd_models.list_models() # to remove the potentially missing models from the list
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+ return results
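+
+# Illustrative sketch, not part of run_modelmerger itself: conceptually, "Weighted sum"
+# blends two checkpoints tensor-by-tensor and "Add difference" adds a scaled delta;
+# the helper names below are hypothetical.
+#
+# def weighted_sum(a, b, m):
+#     return a * (1 - m) + b * m
+#
+# def add_difference(a, b, c, m):
+#     return a + (b - c) * m
+#
+# merged = {key: weighted_sum(theta_a[key], theta_b[key], m) for key in theta_a}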
+
+
+class UiCheckpointMerger:
+ def __init__(self):
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='compact'):
+ self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
+
+ with FormRow(elem_id="modelmerger_models"):
+ self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
+ self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
+ self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+ create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+ self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+ self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+ self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
+
+ with FormRow():
+ self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
+ with FormRow():
+ with gr.Column():
+ self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+ with gr.Column():
+ with FormRow():
+ self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+ create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+
+ with FormRow():
+ self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
+ with gr.Accordion("Metadata", open=False) as metadata_editor:
+ with FormRow():
+ self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
+ self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
+ self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
+
+ self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
+ self.read_metadata = gr.Button("Read metadata from selected checkpoints")
+
+ with FormRow():
+ self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+ with gr.Group(elem_id="modelmerger_results_panel"):
+ self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+
+ self.metadata_editor = metadata_editor
+ self.blocks = modelmerger_interface
+
+ def setup_ui(self, dummy_component, sd_model_checkpoint_component):
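+ # hide the metadata editor for formats that cannot store it (only .safetensors can)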
+ self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
+
+ self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
+
+ self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
+ self.modelmerger_merge.click(
+ fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
+ inputs=[
+ dummy_component,
+ self.primary_model_name,
+ self.secondary_model_name,
+ self.tertiary_model_name,
+ self.interp_method,
+ self.interp_amount,
+ self.save_as_half,
+ self.custom_name,
+ self.checkpoint_format,
+ self.config_source,
+ self.bake_in_vae,
+ self.discard_weights,
+ self.save_metadata,
+ self.add_merge_recipe,
+ self.copy_metadata_fields,
+ self.metadata_json,
+ ],
+ outputs=[
+ self.primary_model_name,
+ self.secondary_model_name,
+ self.tertiary_model_name,
+ sd_model_checkpoint_component,
+ self.modelmerger_result,
+ ]
+ )
+
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ self.interp_description.value = update_interp_description(self.interp_method.value)
+
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 11eb2a4b..ba75fa73 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -223,20 +223,44 @@ Requested path was: {f}
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+ label = None
+ for comp in refresh_components:
+ label = getattr(comp, 'label', None)
+ if label is not None:
+ break
+
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
- setattr(refresh_component, k, v)
+ for comp in refresh_components:
+ setattr(comp, k, v)
- return gr.update(**(args or {}))
+ return [gr.update(**(args or {})) for _ in refresh_components]
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
refresh_button.click(
fn=refresh,
inputs=[],
- outputs=[refresh_component]
+ outputs=[*refresh_components]
)
return refresh_button
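+
+# Usage sketch (hypothetical names): passing a list refreshes several components
+# from one button:
+#
+# create_refresh_button([dropdown_a, dropdown_b], reload_items,
+#                       lambda: {"choices": list_items()}, "refresh_items")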
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+ """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+ dialog.visible = False
+
+ button_show.click(
+ fn=lambda: gr.update(visible=True),
+ inputs=[],
+ outputs=[dialog],
+ ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+ if button_close:
+ button_close.click(fn=None, _js="closePopup")
+
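+# Example usage (see UiPromptStyles in modules/ui_prompt_styles.py below):
+# setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)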
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 76780cfd..891d8f2c 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.refresh_checkpoints()
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
path, ext = os.path.splitext(checkpoint.filename)
return {
@@ -23,6 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
"local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": checkpoint.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
}
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb42..514a4562 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.reload_hypernetworks()
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path)
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e50..73134698 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
path, ext = os.path.splitext(embedding.filename)
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 63d4b503..1cb9eb6f 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -96,6 +96,7 @@ class UserMetadataEditor:
stats = os.stat(filename)
params = [
+ ('Filename: ', os.path.basename(filename)),
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
]
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 00000000..85eb3a64
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
+styles_materialize_symbol = '\U0001f4cb' # 📋
+
+
+def select_style(name):
+ style = shared.prompt_styles.styles.get(name)
+ existing = style is not None
+ empty = not name
+
+ prompt = style.prompt if style else gr.update()
+ negative_prompt = style.negative_prompt if style else gr.update()
+
+ return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
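+
+# The four outputs above feed [prompt, neg_prompt, delete button, save button] in
+# UiPromptStyles.selection.change below: Delete is only visible for an existing
+# style, and Save only once a name has been entered.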
+
+
+def save_style(name, prompt, negative_prompt):
+ if not name:
+ return gr.update(visible=False)
+
+ style = styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return gr.update(visible=True)
+
+
+def delete_style(name):
+ if name == "":
+ return
+
+ shared.prompt_styles.styles.pop(name, None)
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
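+
+# Worked example (assumed from the behavior described in the selection dropdown's
+# info text): a style whose text is "a photo of {prompt}, 4k" applied to the prompt
+# "a cat" yields "a photo of a cat, 4k"; a style without a {prompt} token, such as
+# "masterpiece", is appended instead, giving "a cat, masterpiece".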
+
+
+def refresh_styles():
+ return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+ self.tabname = tabname
+
+ with gr.Row(elem_id=f"{tabname}_styles_row"):
+ self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+ edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+ with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+ with gr.Row():
+ self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+ ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+ self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selection dropdown in the main UI to the prompt.")
+
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+ with gr.Row():
+ self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+ with gr.Row():
+ self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+ self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+ self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+ self.selection.change(
+ fn=select_style,
+ inputs=[self.selection],
+ outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+ show_progress=False,
+ )
+
+ self.save.click(
+ fn=save_style,
+ inputs=[self.selection, self.prompt, self.neg_prompt],
+ outputs=[self.delete],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.delete.click(
+ fn=delete_style,
+ _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+ inputs=[self.selection],
+ outputs=[self.selection, self.prompt, self.neg_prompt],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.materialize.click(
+ fn=materialize_styles,
+ inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ show_progress=False,
+ ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+ ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+
+
+
+
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 1010845e..103bf104 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -3,6 +3,7 @@ from copy import copy
from itertools import permutations, chain
import random
import csv
+import os.path
from io import StringIO
from PIL import Image
import numpy as np
@@ -10,7 +11,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
@@ -182,6 +183,8 @@ def do_nothing(p, x, xs):
def format_nothing(p, opt, x):
return ""
+def format_remove_path(p, opt, x):
+ return os.path.basename(x)
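+# (e.g. a value of "subdir/model.safetensors" is shown as "model.safetensors" on the grid axis)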
def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
@@ -223,7 +226,7 @@ axis_options = [
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")),
@@ -648,7 +651,12 @@ class Script(scripts.Script):
y_opt.apply(pc, y, ys)
z_opt.apply(pc, z, zs)
- res = process_images(pc)
+ try:
+ res = process_images(pc)
+ except Exception as e:
+ errors.display(e, "generating image for xyz plot")
+
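+ # fall back to an empty result so the remaining grid cells can still be processed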
+ res = Processed(p, [], p.seed, "")
# Sets subgrid infotexts
subgrid_index = 1 + iz
diff --git a/style.css b/style.css
index 6c92d6e7..cf8470e4 100644
--- a/style.css
+++ b/style.css
@@ -972,3 +972,16 @@ div.block.gradio-box.edit-user-metadata {
.edit-user-metadata-buttons{
margin-top: 1.5em;
}
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+ width: 56em;
+ background: var(--body-background-fill);
+ padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+ margin-top: 1em;
+}
diff --git a/webui.py b/webui.py
index 6bf06854..8d84e5a4 100644
--- a/webui.py
+++ b/webui.py
@@ -58,10 +58,10 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
-import modules.face_restoration
import modules.gfpgan_model as gfpgan
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
import modules.img2img
import modules.lowvram
@@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False):
if modules.sd_hijack.current_optimizer is None:
modules.sd_hijack.apply_optimizations()
- Thread(target=load_model).start()
+ devices.first_time_calculation()
- Thread(target=devices.first_time_calculation).start()
+ Thread(target=load_model).start()
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")