aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py2
-rw-r--r--modules/esrgan_model_arch.py1
-rw-r--r--modules/extensions.py6
-rw-r--r--modules/generation_parameters_copypaste.py7
-rw-r--r--modules/hypernetworks/hypernetwork.py4
-rw-r--r--modules/images.py18
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/processing.py24
-rw-r--r--modules/script_callbacks.py29
-rw-r--r--modules/sd_hijack.py2
-rw-r--r--modules/sd_hijack_inpainting.py1
-rw-r--r--modules/sd_hijack_unet.py11
-rw-r--r--modules/sd_models.py2
-rw-r--r--modules/sd_samplers_kdiffusion.py42
-rw-r--r--modules/shared.py7
-rw-r--r--modules/shared_items.py2
-rw-r--r--modules/ui.py12
-rw-r--r--modules/ui_extensions.py8
-rw-r--r--modules/ui_extra_networks.py10
19 files changed, 142 insertions, 48 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index eb7b1da5..5a9ac5f1 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -498,7 +498,7 @@ class Api:
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
- hypernetwork, filename = train_hypernetwork(*args)
+ hypernetwork, filename = train_hypernetwork(**args)
except Exception as e:
error = e
finally:
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index bc9ceb2a..1b52b0f5 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -1,5 +1,6 @@
# this file is adapted from https://github.com/victorca25/iNNfer
+from collections import OrderedDict
import math
import functools
import torch
diff --git a/modules/extensions.py b/modules/extensions.py
index 5e12b1aa..3eef9eaf 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
+import time
import git
from modules import paths, shared
@@ -25,6 +26,7 @@ class Extension:
self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.version = ''
repo = None
try:
@@ -40,6 +42,10 @@ class Extension:
try:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
+ head = repo.head.commit
+ ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
+ self.version = f'{head.hexsha[:8]} ({ts})'
+
except Exception:
self.remote = None
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index fc9e17aa..89dc23bf 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -74,8 +74,8 @@ def image_from_url_text(filedata):
return image
-def add_paste_fields(tabname, init_img, fields):
- paste_fields[tabname] = {"init_img": init_img, "fields": fields}
+def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
import modules.ui
@@ -110,6 +110,7 @@ def connect_paste_params_buttons():
for binding in registered_param_bindings:
destination_image_component = paste_fields[binding.tabname]["init_img"]
fields = paste_fields[binding.tabname]["fields"]
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
@@ -130,7 +131,7 @@ def connect_paste_params_buttons():
)
if binding.source_text_component is not None and fields is not None:
- connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
if binding.source_tabname is not None and fields is not None:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f4fb69e0..f6ef42d5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -380,8 +380,8 @@ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context_k)
- context_v = hypernetwork_layers[1](context_v)
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
return context_k, context_v
diff --git a/modules/images.py b/modules/images.py
index c2ca8849..38404de3 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -18,7 +18,7 @@ import string
import json
import hashlib
-from modules import sd_samplers, shared, script_callbacks
+from modules import sd_samplers, shared, script_callbacks, errors
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -553,6 +553,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
if image_to_save.mode == 'RGBA':
image_to_save = image_to_save.convert("RGB")
+ elif image_to_save.mode == 'I;16':
+ image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
@@ -575,17 +577,19 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image.already_saved_as = fullfn
- target_side_length = 4000
- oversize = image.width > target_side_length or image.height > target_side_length
- if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
+ oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
if oversize and ratio > 1:
- image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
+ image = image.resize((opts.target_side_length, image.height * opts.target_side_length // image.width), LANCZOS)
elif oversize:
- image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
+ image = image.resize((image.width * opts.target_side_length // image.height, opts.target_side_length), LANCZOS)
- _atomically_save_image(image, fullfn_without_extension, ".jpg")
+ try:
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
+ except Exception as e:
+ errors.display(e, "saving image as downscaled JPG")
if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt"
diff --git a/modules/img2img.py b/modules/img2img.py
index bcc158dc..c973b770 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -73,6 +73,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
if not save_normally:
os.makedirs(output_dir, exist_ok=True)
+ if processed_image.mode == 'RGBA':
+ processed_image = processed_image.convert("RGB")
processed_image.save(os.path.join(output_dir, filename))
diff --git a/modules/processing.py b/modules/processing.py
index e1b53ac0..2009d3bf 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -543,8 +543,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
- _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
-
if p.scripts is not None:
p.scripts.process(p)
@@ -582,13 +580,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
- if not p.disable_extra_networks:
- extra_networks.activate(p, extra_network_data)
-
- with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
if state.job_count == -1:
state.job_count = p.n_iter
@@ -609,11 +600,24 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
- prompts, _ = extra_networks.parse_prompts(prompts)
+ prompts, extra_network_data = extra_networks.parse_prompts(prompts)
+
+ if not p.disable_extra_networks:
+ with devices.autocast():
+ extra_networks.activate(p, extra_network_data)
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ # params.txt should be saved after scripts.process_batch, since the
+ # infotext could be modified by that callback
+ # Example: a wildcard processed by process_batch sets an extra model
+ # strength, which is saved as "Model Strength: 1.0" in the infotext
+ if n == 0:
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 4bb45ec7..edd0e2a7 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,6 +46,18 @@ class CFGDenoiserParams:
"""Total number of sampling steps planned"""
+class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
self.txt2img_preview_params = txt2img_preview_params
@@ -68,6 +80,7 @@ callback_map = dict(
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
+ callbacks_cfg_denoised=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
report_exception(c, 'cfg_denoiser_callback')
+def cfg_denoised_callback(params: CFGDenoisedParams):
+ for c in callback_map['callbacks_cfg_denoised']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_denoised_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -283,6 +304,14 @@ def on_cfg_denoiser(callback):
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+def on_cfg_denoised(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_denoised'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 57ed5635..79476783 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -154,6 +154,8 @@ class StableDiffusionModelHijack:
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
apply_weighted_forward(m)
+ if m.cond_stage_key == "edit":
+ sd_hijack_unet.hijack_ddpm_edit()
self.optimization_method = apply_optimizations()
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 478cd499..55a2ce4d 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -11,6 +11,7 @@ import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
@torch.no_grad()
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 45cf2b18..843ab66c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
@@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
else:
return torch.nn.GELU.forward(self, x)
+
+ddpm_edit_hijack = None
+def hijack_ddpm_edit():
+ global ddpm_edit_hijack
+ if not ddpm_edit_hijack:
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+
+
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index d847d358..127e9663 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -105,7 +105,7 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
checkpoint_alisases.clear()
- model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
+ model_list = modelloader.load_models(model_path=model_path, model_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index f076fc55..528f513f 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -8,6 +8,7 @@ from modules import prompt_parser, devices, sd_samplers_common
from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -136,6 +137,9 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ cfg_denoised_callback(denoised_params)
+
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
@@ -269,6 +273,16 @@ class KDiffusionSampler:
return sigmas
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
+
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@@ -278,18 +292,24 @@ class KDiffusionSampler:
xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
+ parameters = inspect.signature(self.func).parameters
+
+ if 'sigma_min' in parameters:
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
- if 'sigma_max' in inspect.signature(self.func).parameters:
+ if 'sigma_max' in parameters:
extra_params_kwargs['sigma_max'] = sigma_sched[0]
- if 'n' in inspect.signature(self.func).parameters:
+ if 'n' in parameters:
extra_params_kwargs['n'] = len(sigma_sched) - 1
- if 'sigma_sched' in inspect.signature(self.func).parameters:
+ if 'sigma_sched' in parameters:
extra_params_kwargs['sigma_sched'] = sigma_sched
- if 'sigmas' in inspect.signature(self.func).parameters:
+ if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigma_sched
+ if self.funcname == 'sample_dpmpp_sde':
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+
self.model_wrap_cfg.init_latent = x
self.last_latent = x
extra_args={
@@ -303,7 +323,7 @@ class KDiffusionSampler:
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps = steps or p.steps
sigmas = self.get_sigmas(p, steps)
@@ -311,14 +331,20 @@ class KDiffusionSampler:
x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
+ parameters = inspect.signature(self.func).parameters
+
+ if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in inspect.signature(self.func).parameters:
+ if 'n' in parameters:
extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
+ if self.funcname == 'sample_dpmpp_sde':
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,
diff --git a/modules/shared.py b/modules/shared.py
index 79fbf724..e324a48a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -325,7 +325,9 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
+ "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
+ "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
+ "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
@@ -364,7 +366,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
- "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
+ "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
@@ -414,6 +416,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
}))
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 8b5ec96d..e792a134 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -20,4 +20,4 @@ def sd_vae_items():
def refresh_vae_list():
import modules.sd_vae
- return modules.sd_vae.refresh_vae_list
+ modules.sd_vae.refresh_vae_list()
diff --git a/modules/ui.py b/modules/ui.py
index efb87c23..0516c643 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -631,9 +631,9 @@ def create_ui():
(hr_resize_y, "Hires resize-2"),
*modules.scripts.scripts_txt2img.infotext_fields
]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings,
+ paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
))
txt2img_preview_params = [
@@ -963,10 +963,10 @@ def create_ui():
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
]
- parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
- parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings,
+ paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
))
modules.scripts.scripts_current = None
@@ -1786,7 +1786,7 @@ def versions_html():
return f"""
python: <span title="{sys.version}">{python_version}</span>
 • 
-torch: {torch.__version__}
+torch: {getattr(torch, '__long_version__',torch.__version__)}
 • 
xformers: {xformers_version}
 • 
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 37d30e1f..bd4308ef 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -80,6 +80,7 @@ def extension_table():
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
+ <th><abbr title="Extension version">Version</abbr></th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
@@ -87,11 +88,7 @@ def extension_table():
"""
for ext in extensions.extensions:
- remote = ""
- if ext.is_builtin:
- remote = "built-in"
- elif ext.remote:
- remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
+ remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
@@ -102,6 +99,7 @@ def extension_table():
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
+ <td>{ext.version}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 90abec0a..71f1d81f 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -76,6 +76,10 @@ class ExtraNetworksPage:
while subdir.startswith("/"):
subdir = subdir[1:]
+ is_empty = len(os.listdir(x)) == 0
+ if not is_empty and not subdir.endswith("/"):
+ subdir = subdir + "/"
+
subdirs[subdir] = 1
if subdirs:
@@ -94,11 +98,13 @@ class ExtraNetworksPage:
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+ self_name_id = self.name.replace(" ", "_")
+
res = f"""
-<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
{subdirs_html}
</div>
-<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
{items_html}
</div>
"""