aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/devices.py83
-rw-r--r--modules/errors.py50
-rw-r--r--modules/extensions.py10
-rw-r--r--modules/extra_networks.py19
-rw-r--r--modules/generation_parameters_copypaste.py3
-rw-r--r--modules/gradio_extensons.py60
-rw-r--r--modules/hypernetworks/hypernetwork.py5
-rw-r--r--modules/images.py2
-rw-r--r--modules/img2img.py6
-rw-r--r--modules/processing.py111
-rw-r--r--modules/prompt_parser.py9
-rw-r--r--modules/rng_philox.py102
-rw-r--r--modules/scripts.py60
-rw-r--r--modules/sd_hijack.py6
-rw-r--r--modules/sd_hijack_clip.py2
-rw-r--r--modules/sd_hijack_optimizations.py4
-rw-r--r--modules/sd_models.py29
-rw-r--r--modules/sd_samplers_common.py20
-rw-r--r--modules/sd_samplers_kdiffusion.py9
-rw-r--r--modules/sd_vae.py16
-rw-r--r--modules/shared.py53
-rw-r--r--modules/styles.py5
-rw-r--r--modules/textual_inversion/textual_inversion.py4
-rw-r--r--modules/txt2img.py3
-rw-r--r--modules/ui.py352
-rw-r--r--modules/ui_checkpoint_merger.py2
-rw-r--r--modules/ui_common.py34
-rw-r--r--modules/ui_components.py2
-rw-r--r--modules/ui_extensions.py26
-rw-r--r--modules/ui_extra_networks.py13
-rw-r--r--modules/ui_extra_networks_checkpoints.py5
-rw-r--r--modules/ui_extra_networks_checkpoints_user_metadata.py60
-rw-r--r--modules/ui_extra_networks_hypernets.py2
-rw-r--r--modules/ui_extra_networks_textual_inversion.py2
-rw-r--r--modules/ui_postprocessing.py2
-rw-r--r--modules/ui_prompt_styles.py110
-rw-r--r--modules/ui_settings.py2
38 files changed, 871 insertions, 414 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index cb4ec5f7..64f21e01 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -112,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..00a00b18 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors
+from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True
-
errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -90,23 +93,87 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
+nv_rng = None
+
+
def randn(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
from modules.shared import opts
- torch.manual_seed(seed)
+ manual_seed(seed)
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=device)
+
+ local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
def randn_without_seed(shape):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
def autocast(disable=False):
from modules import shared
diff --git a/modules/errors.py b/modules/errors.py
index dffabe45..192cd8ff 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -84,3 +84,53 @@ def run(code, task):
code()
except Exception as e:
display(task, e)
+
+
+def check_versions():
+ from packaging import version
+ from modules import shared
+
+ import torch
+ import gradio
+
+ expected_torch_version = "2.0.0"
+ expected_xformers_version = "0.0.20"
+ expected_gradio_version = "3.39.0"
+
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
+ print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if shared.xformers_available:
+ import xformers
+
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+ print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if gradio.__version__ != expected_gradio_version:
+ print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Reasons why you have the mismatched gradio version can be:
+ - you use --skip-install flag.
+ - you use webui.py to start the program instead of launch.py.
+ - an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
diff --git a/modules/extensions.py b/modules/extensions.py
index 3ad5ed53..e4633af4 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
def active():
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
return []
- elif shared.opts.disable_all_extensions == "extra":
+ elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
@@ -141,8 +141,12 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions:
+ print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+ elif shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+ elif shared.cmd_opts.disable_extra_extensions:
+ print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 6ae07e91..fa28ac75 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,3 +1,5 @@
+import json
+import os
import re
from collections import defaultdict
@@ -177,3 +179,20 @@ def parse_prompts(prompts):
return res, extra_data
+
+def get_user_metadata(filename):
+ if filename is None:
+ return {}
+
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ return metadata
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a3448be9..4e286558 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler"
+ if "Hires checkpoint" not in res:
+ res["Hires checkpoint"] = "Use same checkpoint"
+
if "Hires prompt" not in res:
res["Hires prompt"] = ""
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
new file mode 100644
index 00000000..5af7fd8e
--- /dev/null
+++ b/modules/gradio_extensons.py
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import scripts
+
+def add_classes_to_gradio_component(comp):
+ """
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+ """
+
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+
+ if getattr(comp, 'multiselect', False):
+ comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+ self.webui_tooltip = kwargs.pop('tooltip', None)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.before_component(self, **kwargs)
+
+ scripts.script_callbacks.before_component_callback(self, **kwargs)
+
+ res = original_IOComponent_init(self, *args, **kwargs)
+
+ add_classes_to_gradio_component(self)
+
+ scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.after_component(self, **kwargs)
+
+ return res
+
+
+def Block_get_config(self):
+ config = original_Block_get_config(self)
+
+ webui_tooltip = getattr(self, 'webui_tooltip', None)
+ if webui_tooltip:
+ config["webui_tooltip"] = webui_tooltip
+
+ return config
+
+
+def BlockContext_init(self, *args, **kwargs):
+ res = original_BlockContext_init(self, *args, **kwargs)
+
+ add_classes_to_gradio_component(self)
+
+ return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.blocks.Block.get_config
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+
+gr.components.IOComponent.__init__ = IOComponent_init
+gr.blocks.Block.get_config = Block_get_config
+gr.blocks.BlockContext.__init__ = BlockContext_init
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c4821d21..70f1cbd2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -10,7 +10,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
- from modules import images
+ from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
diff --git a/modules/images.py b/modules/images.py
index 38aa933d..ba3c43a4 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res
-invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
diff --git a/modules/img2img.py b/modules/img2img.py
index 68e415ef..d8e1c534 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,7 +3,7 @@ from contextlib import closing
from pathlib import Path
import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr
from modules import sd_samplers, images as imgutil
@@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
mask = None
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask).convert('L')
+ mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
diff --git a/modules/processing.py b/modules/processing.py
index b0992ee1..ae58b108 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+decode_first_stage = sd_samplers_common.decode_first_stage
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -492,7 +493,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
subnoise = None
- if subseeds is not None:
+ if subseeds is not None and subseed_strength != 0:
subseed = 0 if i >= len(subseeds) else subseeds[i]
subnoise = devices.randn(subseed, noise_shape)
@@ -524,7 +525,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
cnt = p.sampler.number_of_needed_noises(p)
if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
+ devices.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -538,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+class DecodedSamples(list):
+ already_decoded = True
+
+
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
- samples = []
+ samples = DecodedSamples()
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -572,12 +577,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
return samples
-def decode_first_stage(model, x):
- x = model.decode_first_stage(x.to(devices.dtype_vae))
-
- return x
-
-
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -636,7 +635,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
- "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
@@ -793,7 +792,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ if getattr(samples_ddim, 'already_decoded', False):
+ x_samples_ddim = samples_ddim
+ else:
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -935,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -946,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_checkpoint_name = hr_checkpoint_name
+ self.hr_checkpoint_info = None
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
+ self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -973,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_checkpoint_name:
+ self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+ if self.hr_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+ self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -982,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
+ self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and self.latent_scale_mode is None:
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -1020,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
- # special case: the user has chosen to do nothing
- if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
- self.enable_hr = False
- self.denoising_strength = None
- self.extra_generation_params.pop("Hires upscale", None)
- self.extra_generation_params.pop("Hires resize", None)
- return
-
if not state.processing_has_refined_job_count:
if state.job_count == -1:
state.job_count = self.n_iter
@@ -1045,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
- if self.enable_hr and latent_scale_mode is None:
- if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
- raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ del x
if not self.enable_hr:
return samples
+ if self.latent_scale_mode is None:
+ decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+ else:
+ decoded_samples = None
+
+ current = shared.sd_model.sd_checkpoint_info
+ try:
+ if self.hr_checkpoint_info is not None:
+ self.sampler = None
+ sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+ devices.torch_gc()
+
+ return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+ finally:
+ self.sampler = None
+ sd_models.reload_model_weights(info=current)
+ devices.torch_gc()
+
+ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
@@ -1073,11 +1099,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
- if latent_scale_mode is not None:
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+ if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+ img2img_sampler_name = 'DDIM'
+
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+ if self.latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -1086,7 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = []
@@ -1105,6 +1137,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device)
decoded_samples = 2. * decoded_samples - 1.
+ decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
@@ -1112,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
- if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
- img2img_sampler_name = 'DDIM'
-
- self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
- x = None
devices.torch_gc()
if not self.disable_extra_networks:
@@ -1143,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+ decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
self.is_hr_pass = False
- return samples
+ return decoded_samples
def close(self):
super().close()
@@ -1184,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_c is not None:
return
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+ hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self):
super().setup_conds()
@@ -1193,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = None
self.hr_c = None
- if self.enable_hr:
+ if self.enable_hr and self.hr_checkpoint_info is None:
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()
@@ -1348,6 +1378,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = image.to(shared.device, dtype=devices.dtype_vae)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ devices.torch_gc()
if self.resize_mode == 3:
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 8169a459..32d214e3 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -20,7 +20,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
-alternate: "[" prompt ("|" prompt)+ "]"
+alternate: "[" prompt ("|" [prompt])+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
@@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
>>> g("[a|(b:1.1)]")
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+ >>> g("[fe|]male")
+ [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
+ >>> g("[fe|||]male")
+ [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
"""
def collect_steps(steps, tree):
@@ -78,7 +82,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
before, after, _, when, _ = args
yield before or () if step <= when else after
def alternate(self, args):
- yield next(args[(step - 1)%len(args)])
+ args = ["" if not arg else arg for arg in args]
+ yield args[(step - 1) % len(args)]
def start(self, args):
def flatten(x):
if type(x) == str:
diff --git a/modules/rng_philox.py b/modules/rng_philox.py
new file mode 100644
index 00000000..5532cf9d
--- /dev/null
+++ b/modules/rng_philox.py
@@ -0,0 +1,102 @@
+"""RNG imitiating torch cuda randn on CPU. You are welcome.
+
+Usage:
+
+```
+g = Generator(seed=0)
+print(g.randn(shape=(3, 4)))
+```
+
+Expected output:
+```
+[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
+ [-1.07454231 -0.36314407 -1.67105067 2.26550497]]
+```
+"""
+
+import numpy as np
+
+philox_m = [0xD2511F53, 0xCD9E8D57]
+philox_w = [0x9E3779B9, 0xBB67AE85]
+
+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
+
+
+def uint32(x):
+ """Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
+ return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
+
+
+def philox4_round(counter, key):
+ """A single round of the Philox 4x32 random number generator."""
+
+ v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
+ v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
+
+ counter[0] = v2[1] ^ counter[1] ^ key[0]
+ counter[1] = v2[0]
+ counter[2] = v1[1] ^ counter[3] ^ key[1]
+ counter[3] = v1[0]
+
+
+def philox4_32(counter, key, rounds=10):
+ """Generates 32-bit random numbers using the Philox 4x32 random number generator.
+
+ Parameters:
+ counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
+ key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
+ rounds (int): The number of rounds to perform.
+
+ Returns:
+ numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
+ """
+
+ for _ in range(rounds - 1):
+ philox4_round(counter, key)
+
+ key[0] = key[0] + philox_w[0]
+ key[1] = key[1] + philox_w[1]
+
+ philox4_round(counter, key)
+ return counter
+
+
+def box_muller(x, y):
+ """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
+ u = x * two_pow32_inv + two_pow32_inv / 2
+ v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+
+ s = np.sqrt(-2.0 * np.log(u))
+
+ r1 = s * np.sin(v)
+ return r1.astype(np.float32)
+
+
+class Generator:
+ """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
+
+ def __init__(self, seed):
+ self.seed = seed
+ self.offset = 0
+
+ def randn(self, shape):
+ """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
+
+ n = 1
+ for x in shape:
+ n *= x
+
+ counter = np.zeros((4, n), dtype=np.uint32)
+ counter[0] = self.offset
+ counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
+ self.offset += 1
+
+ key = np.empty(n, dtype=np.uint64)
+ key.fill(self.seed)
+ key = uint32(key)
+
+ g = philox4_32(counter, key)
+
+ return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
diff --git a/modules/scripts.py b/modules/scripts.py
index edf7347e..f7d060aa 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -631,63 +631,3 @@ def reload_script_body_only():
reload_scripts = load_scripts # compatibility alias
-
-
-def add_classes_to_gradio_component(comp):
- """
- this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
- """
-
- comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
-
- if getattr(comp, 'multiselect', False):
- comp.elem_classes.append('multiselect')
-
-
-
-def IOComponent_init(self, *args, **kwargs):
- self.webui_tooltip = kwargs.pop('tooltip', None)
-
- if scripts_current is not None:
- scripts_current.before_component(self, **kwargs)
-
- script_callbacks.before_component_callback(self, **kwargs)
-
- res = original_IOComponent_init(self, *args, **kwargs)
-
- add_classes_to_gradio_component(self)
-
- script_callbacks.after_component_callback(self, **kwargs)
-
- if scripts_current is not None:
- scripts_current.after_component(self, **kwargs)
-
- return res
-
-
-def Block_get_config(self):
- config = original_Block_get_config(self)
-
- webui_tooltip = getattr(self, 'webui_tooltip', None)
- if webui_tooltip:
- config["webui_tooltip"] = webui_tooltip
-
- return config
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-original_Block_get_config = gr.components.Block.get_config
-gr.components.IOComponent.__init__ = IOComponent_init
-gr.components.Block.get_config = Block_get_config
-
-
-def BlockContext_init(self, *args, **kwargs):
- res = original_BlockContext_init(self, *args, **kwargs)
-
- add_classes_to_gradio_component(self)
-
- return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 9722c967..9ad98199 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,7 +2,6 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
@@ -168,12 +167,13 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
def __init__(self):
+ import modules.textual_inversion.textual_inversion
+
self.extra_generation_params = {}
self.comments = []
+ self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 2f9d569b..8f29057a 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
hashes.append(f"{name}: {shorthash}")
if hashes:
+ if self.hijack.extra_generation_params.get("TI hashes"):
+ hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if getattr(self.wrapped, 'return_pooled', False):
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b5f85ba5..0e810eec 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ slice_size = q.shape[1] // steps
for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
+ end = min(i + slice_size, q.shape[1])
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3c451a4b..f6051604 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -66,8 +66,9 @@ class CheckpointInfo:
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+ self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -86,6 +87,7 @@ class CheckpointInfo:
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
+ self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
self.register()
return self.shorthash
@@ -106,14 +108,8 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+ return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -136,11 +132,14 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
- for filename in sorted(model_list, key=str.lower):
+ for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
@@ -150,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
if found:
return found[0]
+ search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+ found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+
return None
@@ -302,12 +306,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_models_xl.extend_sdxl(model)
model.load_state_dict(state_dict, strict=False)
- del state_dict
timer.record("apply weights to model")
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ checkpoints_loaded[checkpoint_info] = state_dict
+
+ del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 763829f1..b3d344e7 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,10 +2,8 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
-
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
-import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -37,7 +35,7 @@ def single_sample_to_image(sample, approximation=None):
x_sample = sample * 1.5
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+ x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -46,6 +44,12 @@ def single_sample_to_image(sample, approximation=None):
return Image.fromarray(x_sample)
+def decode_first_stage(model, x):
+ x = model.decode_first_stage(x.to(devices.dtype_vae))
+
+ return x
+
+
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
@@ -85,11 +89,13 @@ class InterruptedException(BaseException):
pass
-if opts.randn_source == "CPU":
+def replace_torchsde_browinan():
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
- generator = torch.Generator(devices.cpu).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ return devices.randn_local(seed, size).to(device=device, dtype=dtype)
torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_browinan()
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e0da3425..8bb639f5 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -30,6 +30,7 @@ samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+ ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
@@ -260,10 +261,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
- if opts.randn_source == "CPU" or x.device.type == 'mps':
- return torch.randn_like(x, device=devices.cpu).to(x.device)
- else:
- return torch.randn_like(x)
+ return devices.randn_like(x)
class KDiffusionSampler:
@@ -378,6 +376,9 @@ class KDiffusionSampler:
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index e4ff2994..0bd5e19b 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,6 +1,6 @@
import os
import collections
-from modules import paths, shared, devices, script_callbacks, sd_models
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
import glob
from copy import deepcopy
@@ -16,6 +16,7 @@ checkpoint_info = None
checkpoints_loaded = collections.OrderedDict()
+
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae
@@ -50,6 +51,7 @@ def get_filename(filepath):
def refresh_vae_list():
+ global vae_dict
vae_dict.clear()
paths = [
@@ -83,6 +85,8 @@ def refresh_vae_list():
name = get_filename(filepath)
vae_dict[name] = filepath
+ vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+
def find_vae_near_checkpoint(checkpoint_file):
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
@@ -97,6 +101,16 @@ def resolve_vae(checkpoint_file):
if shared.cmd_opts.vae_path is not None:
return shared.cmd_opts.vae_path, 'from commandline argument'
+ metadata = extra_networks.get_user_metadata(checkpoint_file)
+ vae_metadata = metadata.get("vae", None)
+ if vae_metadata is not None and vae_metadata != "Automatic":
+ if vae_metadata == "None":
+ return None, None
+
+ vae_from_metadata = vae_dict.get(vae_metadata, None)
+ if vae_from_metadata is not None:
+ return vae_from_metadata, "from user metadata"
+
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
diff --git a/modules/shared.py b/modules/shared.py
index 0184fcd0..8245250a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -220,12 +220,19 @@ class State:
return
import modules.sd_samplers
- if opts.show_progress_grid:
- self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
- else:
- self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
- self.current_image_sampling_step = self.sampling_step
+ try:
+ if opts.show_progress_grid:
+ self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+ else:
+ self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+ self.current_image_sampling_step = self.sampling_step
+
+ except Exception:
+ # when switching models during genration, VAE would be on CPU, so creating an image will fail.
+ # we silently ignore this error
+ errors.record_exception()
def assign_current_image(self, image):
self.current_image = image
@@ -385,7 +392,8 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
}))
options_templates.update(options_section(('system', "System"), {
- "show_warnings": OptionInfo(False, "Show warnings in console."),
+ "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),
+ "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
@@ -419,19 +427,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
- "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
- "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
- "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
- "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
+ "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(),
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
- "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
@@ -441,6 +444,22 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
+
+options_templates.update(options_section(('img2img', "img2img"), {
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
+ "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+ "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
+ "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
+ "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(),
+ "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_restart(),
+ "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(),
+ "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+ "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
+}))
+
+
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
@@ -460,7 +479,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
}))
-options_templates.update(options_section(('interrogate', "Interrogate Options"), {
+options_templates.update(options_section(('interrogate', "Interrogate"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
@@ -493,10 +512,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
- "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
"return_grid": OptionInfo(True, "Show grid in results for web"),
- "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
- "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -515,11 +531,12 @@ options_templates.update(options_section(('ui', "User interface"), {
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
- "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
+ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))
+
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
diff --git a/modules/styles.py b/modules/styles.py
index ec0e1bc5..0740fe1b 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -106,10 +106,7 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")
- fd = os.open(path, os.O_RDWR | os.O_CREAT)
- with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
- # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+ with open(path, "w", encoding="utf-8-sig", newline='') as file:
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4713bc2d..aa79dc09 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -13,7 +13,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ from modules import processing
+
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 29d94e8c..935ed418 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_second_pass_steps=hr_second_pass_steps,
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
+ hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
diff --git a/modules/ui.py b/modules/ui.py
index ac2787eb..c49538b6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,34 +12,30 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
+from modules import gradio_extensons # noqa: F401
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript
-
from modules.shared import opts, cmd_opts
-import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
+import modules.hypernetworks.ui as hypernetworks_ui
+import modules.textual_inversion.ui as textual_inversion_ui
+import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
+import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras
create_setting_component = ui_settings.create_setting_component
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -92,19 +88,6 @@ def send_gradio_gallery_to_image(x):
return image_from_url_text(x[0])
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
@@ -129,13 +112,6 @@ def resize_from_to_html(width, height, scale_by):
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
-def apply_styles(prompt, prompt_neg, styles):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
if mode in {0, 1, 3, 4}:
return [interrogation_function(ii_singles[mode]), None]
@@ -172,7 +148,6 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface):
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
- seed.style(container=False)
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
@@ -184,7 +159,6 @@ def create_seed_inputs(target_interface):
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
- subseed.style(container=False)
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
@@ -267,71 +241,78 @@ def update_token_counter(text, steps):
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
+class Toprow:
+ """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
- with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
- with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ def __init__(self, is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+ self.id_part = id_part
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_classes="interrogate-col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
- with gr.Row(elem_id=f"{id_part}_tools"):
- paste = ToolButton(value=paste_symbol, elem_id="paste")
- clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
- prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
- save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
- restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
- token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
- negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
- with gr.Row(elem_id=f"{id_part}_styles_row"):
- prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
- create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+ self.button_interrogate = None
+ self.button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_classes="interrogate-col"):
+ self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+ self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+ self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ self.skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ self.interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Row(elem_id=f"{id_part}_tools"):
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+ self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+ self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+ self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+ self.clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[self.prompt, self.negative_prompt],
+ outputs=[self.prompt, self.negative_prompt],
+ )
+
+ self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
+
+ self.prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[self.prompt_img],
+ outputs=[self.prompt, self.prompt_img],
+ show_progress=False,
+ )
def setup_progressbar(*args, **kwargs):
@@ -415,22 +396,21 @@ def create_ui():
parameters_copypaste.reset()
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+ scripts.scripts_current = scripts.scripts_txt2img
+ scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+ toprow = Toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
- extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
- with gr.Row().style(equal_height=False):
+ with gr.Row(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
- modules.scripts.scripts_txt2img.prepare_ui()
+ scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
@@ -476,6 +456,10 @@ def create_ui():
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+
+ hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+ create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
@@ -498,10 +482,10 @@ def create_ui():
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ custom_inputs = scripts.scripts_txt2img.setup_ui()
else:
- modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+ scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
@@ -532,9 +516,9 @@ def create_ui():
_js="submit",
inputs=[
dummy_component,
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_styles,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
steps,
sampler_index,
restore_faces,
@@ -553,6 +537,7 @@ def create_ui():
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
+ hr_checkpoint_name,
hr_sampler_index,
hr_prompt,
hr_negative_prompt,
@@ -569,12 +554,12 @@ def create_ui():
show_progress=False,
)
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
+ toprow.prompt.submit(**txt2img_args)
+ toprow.submit.click(**txt2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
- restore_progress_button.click(
+ toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressTxt2img",
inputs=[dummy_component],
@@ -587,18 +572,6 @@ def create_ui():
show_progress=False,
)
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ],
- show_progress=False,
- )
-
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
@@ -607,8 +580,8 @@ def create_ui():
)
txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
+ (toprow.prompt, "Prompt"),
+ (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
@@ -617,34 +590,36 @@ def create_ui():
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
+ (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
- (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d),
- (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
+ (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
(hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"),
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
+ (hr_checkpoint_name, "Hires checkpoint"),
(hr_sampler_index, "Hires sampler"),
- (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
+ (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
- *modules.scripts.scripts_txt2img.infotext_fields
+ *scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+ paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
))
txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
+ toprow.prompt,
+ toprow.negative_prompt,
steps,
sampler_index,
cfg_scale,
@@ -653,24 +628,22 @@ def create_ui():
height,
]
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+ scripts.scripts_current = scripts.scripts_img2img
+ scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
-
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+ toprow = Toprow(is_img2img=True)
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
- extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
- with FormRow().style(equal_height=False):
+ with FormRow(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
copy_image_destinations = {}
@@ -692,19 +665,19 @@ def create_ui():
img2img_selected_tab = gr.State(0)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -764,7 +737,7 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
- modules.scripts.scripts_img2img.prepare_ui()
+ scripts.scripts_img2img.prepare_ui()
for category in ordered_ui_categories():
if category == "sampler":
@@ -845,7 +818,7 @@ def create_ui():
elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"):
- custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+ custom_inputs = scripts.scripts_img2img.setup_ui()
elif category == "inpaint":
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
@@ -876,34 +849,22 @@ def create_ui():
outputs=[inpaint_controls, mask_alpha],
)
else:
- modules.scripts.scripts_img2img.setup_ui_for_section(category)
+ scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- img2img_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- img2img_prompt_img
- ],
- outputs=[
- img2img_prompt,
- img2img_prompt_img
- ],
- show_progress=False,
- )
-
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
dummy_component,
dummy_component,
- img2img_prompt,
- img2img_negative_prompt,
- img2img_prompt_styles,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
init_img,
sketch,
init_img_with_mask,
@@ -962,11 +923,11 @@ def create_ui():
inpaint_color_sketch,
init_img_inpaint,
],
- outputs=[img2img_prompt, dummy_component],
+ outputs=[toprow.prompt, dummy_component],
)
- img2img_prompt.submit(**img2img_args)
- submit.click(**img2img_args)
+ toprow.prompt.submit(**img2img_args)
+ toprow.submit.click(**img2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
@@ -978,7 +939,7 @@ def create_ui():
show_progress=False,
)
- restore_progress_button.click(
+ toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressImg2img",
inputs=[dummy_component],
@@ -991,46 +952,24 @@ def create_ui():
show_progress=False,
)
- img2img_interrogate.click(
+ toprow.button_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args,
)
- img2img_deepbooru.click(
+ toprow.button_deepbooru.click(
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
**interrogate_args,
)
- prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
- style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
- for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
- button.click(
- fn=add_style,
- _js="ask_for_style_name",
- # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
- # the same number of parameters, but we only know the style-name after the JavaScript prompt
- inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_styles, img2img_prompt_styles],
- )
-
- for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
- button.click(
- fn=apply_styles,
- _js=js_func,
- inputs=[prompt, negative_prompt, styles],
- outputs=[prompt, negative_prompt, styles],
- )
-
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+ toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [
- (img2img_prompt, "Prompt"),
- (img2img_negative_prompt, "Negative prompt"),
+ (toprow.prompt, "Prompt"),
+ (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
@@ -1040,28 +979,29 @@ def create_ui():
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
+ (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
- (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
- *modules.scripts.scripts_img2img.infotext_fields
+ *scripts.scripts_img2img.infotext_fields
]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+ paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
))
- modules.scripts.scripts_current = None
+ scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
ui_postprocessing.create_ui()
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
- with gr.Row().style(equal_height=False):
+ with gr.Row(equal_height=False):
with gr.Column(variant='panel'):
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
@@ -1086,10 +1026,10 @@ def create_ui():
modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
with gr.Blocks(analytics_enabled=False) as train_interface:
- with gr.Row().style(equal_height=False):
+ with gr.Row(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
- with gr.Row(variant="compact").style(equal_height=False):
+ with gr.Row(variant="compact", equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding", id="create_embedding"):
@@ -1109,7 +1049,7 @@ def create_ui():
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
@@ -1249,12 +1189,12 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
create_embedding.click(
- fn=modules.textual_inversion.ui.create_embedding,
+ fn=textual_inversion_ui.create_embedding,
inputs=[
new_embedding_name,
initialization_text,
@@ -1269,7 +1209,7 @@ def create_ui():
)
create_hypernetwork.click(
- fn=modules.hypernetworks.ui.create_hypernetwork,
+ fn=hypernetworks_ui.create_hypernetwork,
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
@@ -1289,7 +1229,7 @@ def create_ui():
)
run_preprocess.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
@@ -1325,7 +1265,7 @@ def create_ui():
)
train_embedding.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
@@ -1359,7 +1299,7 @@ def create_ui():
)
train_hypernetwork.click(
- fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
dummy_component,
diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py
index 4863d861..f9c5dd6b 100644
--- a/modules/ui_checkpoint_merger.py
+++ b/modules/ui_checkpoint_merger.py
@@ -29,7 +29,7 @@ def modelmerger(*args):
class UiCheckpointMerger:
def __init__(self):
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
- with gr.Row().style(equal_height=False):
+ with gr.Row(equal_height=False):
with gr.Column(variant='compact'):
self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 11eb2a4b..1dda1627 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -134,7 +134,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
generation_info = None
with gr.Column():
@@ -223,20 +223,44 @@ Requested path was: {f}
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+ label = None
+ for comp in refresh_components:
+ label = getattr(comp, 'label', None)
+ if label is not None:
+ break
+
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
- setattr(refresh_component, k, v)
+ for comp in refresh_components:
+ setattr(comp, k, v)
- return gr.update(**(args or {}))
+ return (gr.update(**(args or {})) for _ in refresh_components) if len(refresh_components) > 1 else gr.update(**(args or {}))
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
refresh_button.click(
fn=refresh,
inputs=[],
- outputs=[refresh_component]
+ outputs=refresh_components
)
return refresh_button
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+ """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+ dialog.visible = False
+
+ button_show.click(
+ fn=lambda: gr.update(visible=True),
+ inputs=[],
+ outputs=[dialog],
+ ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+ if button_close:
+ button_close.click(fn=None, _js="closePopup")
+
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 64451df7..8f8a7088 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -35,7 +35,7 @@ class FormColumn(FormComponent, gr.Column):
class FormGroup(FormComponent, gr.Group):
- """Same as gr.Row but fits inside gradio forms"""
+ """Same as gr.Group but fits inside gradio forms"""
def get_block_name(self):
return "group"
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index f3e4fba7..15a8b0bf 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -164,7 +164,7 @@ def extension_table():
ext_status = ext.status
style = ""
- if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
style = STYLE_PRIMARY
version_link = ext.version
@@ -533,16 +533,20 @@ def create_ui():
apply = gr.Button(value=apply_label, variant="primary")
check = gr.Button(value="Check for updates")
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
- extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
- extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
html = ""
- if shared.opts.disable_all_extensions != "none":
- html = """
-<span style="color: var(--primary-400);">
- "Disable all extensions" was set, change it to "none" to load all extensions again
-</span>
- """
+
+ if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
+ if shared.cmd_opts.disable_all_extensions:
+ msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
+ elif shared.opts.disable_all_extensions != "none":
+ msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
+ elif shared.cmd_opts.disable_extra_extensions:
+ msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
+ html = f'<span style="color: var(--primary-400);">{msg}</span>'
+
info = gr.HTML(html)
extensions_table = gr.HTML('Loading...')
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
@@ -565,7 +569,7 @@ def create_ui():
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
- available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
+ available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
@@ -574,7 +578,7 @@ def create_ui():
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
with gr.Row():
- search_extensions_text = gr.Text(label="Search").style(container=False)
+ search_extensions_text = gr.Text(label="Search", container=False)
install_result = gr.HTML()
available_extensions_table = gr.HTML()
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index f2752f10..c6390db7 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -2,7 +2,7 @@ import os.path
import urllib.parse
from pathlib import Path
-from modules import shared, ui_extra_networks_user_metadata, errors
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
@@ -101,16 +101,7 @@ class ExtraNetworksPage:
def read_user_metadata(self, item):
filename = item.get("filename", None)
- basename, ext = os.path.splitext(filename)
- metadata_filename = basename + '.json'
-
- metadata = {}
- try:
- if os.path.isfile(metadata_filename):
- with open(metadata_filename, "r", encoding="utf8") as file:
- metadata = json.load(file)
- except Exception as e:
- errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+ metadata = extra_networks.get_user_metadata(filename)
desc = metadata.get("description", None)
if desc is not None:
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 2bb0a222..77885022 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -3,6 +3,7 @@ import os
from modules import shared, ui_extra_networks, sd_models
from modules.ui_extra_networks import quote_js
+from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.refresh_checkpoints()
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
path, ext = os.path.splitext(checkpoint.filename)
return {
@@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
+ def create_user_metadata_editor(self, ui, tabname):
+ return CheckpointUserMetadataEditor(ui, tabname, self)
diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py
new file mode 100644
index 00000000..2c69aab8
--- /dev/null
+++ b/modules/ui_extra_networks_checkpoints_user_metadata.py
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import ui_extra_networks_user_metadata, sd_vae
+from modules.ui_common import create_refresh_button
+
+
+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+ def __init__(self, ui, tabname, page):
+ super().__init__(ui, tabname, page)
+
+ self.select_vae = None
+
+ def save_user_metadata(self, name, desc, notes, vae):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["notes"] = notes
+ user_metadata["vae"] = vae
+
+ self.write_user_metadata(name, user_metadata)
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+ values = super().put_values_into_components(name)
+
+ return [
+ *values[0:5],
+ user_metadata.get('vae', ''),
+ ]
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ with gr.Row():
+ self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
+ create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ self.create_default_buttons()
+
+ viewed_components = [
+ self.edit_name,
+ self.edit_description,
+ self.html_filedata,
+ self.html_preview,
+ self.edit_notes,
+ self.select_vae,
+ ]
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ edited_components = [
+ self.edit_description,
+ self.edit_notes,
+ self.select_vae,
+ ]
+
+ self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb42..514a4562 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.reload_hypernetworks()
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path)
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e50..73134698 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
path, ext = os.path.splitext(embedding.filename)
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index c7dc1154..802e1ce7 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -6,7 +6,7 @@ import modules.generation_parameters_copypaste as parameters_copypaste
def create_ui():
tab_index = gr.State(value=0)
- with gr.Row().style(equal_height=False, variant='compact'):
+ with gr.Row(equal_height=False, variant='compact'):
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 00000000..85eb3a64
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
+styles_materialize_symbol = '\U0001f4cb' # 📋
+
+
+def select_style(name):
+ style = shared.prompt_styles.styles.get(name)
+ existing = style is not None
+ empty = not name
+
+ prompt = style.prompt if style else gr.update()
+ negative_prompt = style.negative_prompt if style else gr.update()
+
+ return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+ if not name:
+ return gr.update(visible=False)
+
+ style = styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return gr.update(visible=True)
+
+
+def delete_style(name):
+ if name == "":
+ return
+
+ shared.prompt_styles.styles.pop(name, None)
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+ return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+ self.tabname = tabname
+
+ with gr.Row(elem_id=f"{tabname}_styles_row"):
+ self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+ edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+ with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+ with gr.Row():
+ self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+ ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+ self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+ with gr.Row():
+ self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+ with gr.Row():
+ self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+ self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+ self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+ self.selection.change(
+ fn=select_style,
+ inputs=[self.selection],
+ outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+ show_progress=False,
+ )
+
+ self.save.click(
+ fn=save_style,
+ inputs=[self.selection, self.prompt, self.neg_prompt],
+ outputs=[self.delete],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.delete.click(
+ fn=delete_style,
+ _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+ inputs=[self.selection],
+ outputs=[self.selection, self.prompt, self.neg_prompt],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.materialize.click(
+ fn=materialize_styles,
+ inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ show_progress=False,
+ ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+ ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+
+
+
+
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index a6076bf3..6dde4b6a 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -158,7 +158,7 @@ class UiSettings:
loadsave.create_ui()
with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
- gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+ gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
with gr.Row():
with gr.Column(scale=1):