aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py7
-rw-r--r--modules/extras.py37
-rw-r--r--modules/generation_parameters_copypaste.py4
-rw-r--r--modules/images.py51
-rw-r--r--modules/processing.py62
-rw-r--r--modules/safety.py42
-rw-r--r--modules/scripts.py20
-rw-r--r--modules/sd_hijack.py14
-rw-r--r--modules/sd_hijack_unet.py30
-rw-r--r--modules/sd_models.py48
-rw-r--r--modules/sd_vae.py37
-rw-r--r--modules/shared.py3
-rw-r--r--modules/ui.py56
-rw-r--r--modules/ui_extensions.py13
14 files changed, 273 insertions, 151 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 54ee7cb0..89935a70 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -157,12 +157,7 @@ class Api:
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
p = StableDiffusionProcessingImg2Img(**args)
- imgs = []
- for img in init_images:
- img = decode_base64_to_image(img)
- imgs = [img] * p.batch_size
-
- p.init_images = imgs
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
diff --git a/modules/extras.py b/modules/extras.py
index bc349d5e..0ad8deec 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -62,7 +62,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Also keep track of original file names
imageNameArr = []
outputs = []
-
+
if extras_mode == 1:
#convert file to pillow image
for img in image_folder:
@@ -234,7 +234,7 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -246,30 +246,25 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
+ tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
result_is_inpainting_model = False
- print(f"Loading {primary_model_info.filename}...")
- theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
-
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
-
- if teritary_model_info is not None:
- print(f"Loading {teritary_model_info.filename}...")
- theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
- else:
- theta_2 = None
-
theta_funcs = {
"Weighted sum": (None, weighted_sum),
"Add difference": (get_difference, add_difference),
}
theta_func1, theta_func2 = theta_funcs[interp_method]
- print(f"Merging...")
+ if theta_func1 and not tertiary_model_info:
+ return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1:
+ print(f"Loading {tertiary_model_info.filename}...")
+ theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
if key in theta_2:
@@ -277,7 +272,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
- del theta_2
+ del theta_2
+
+ print(f"Loading {primary_model_info.filename}...")
+ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
+
+ print("Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
@@ -307,6 +307,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
+ del theta_1
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
@@ -332,5 +333,5 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
sd_models.list_models()
- print(f"Checkpoint saved.")
+ print("Checkpoint saved.")
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 44fe1a6c..565e342d 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -77,6 +77,7 @@ def integrate_settings_paste_fields(component_dict):
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
'eta_noise_seed_delta': 'ENSD',
+ 'initial_noise_multiplier': 'Noise multiplier',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
@@ -121,8 +122,7 @@ def run_bind():
if send_generate_info and paste_fields[tab]["fields"] is not None:
if send_generate_info in paste_fields:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
-
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
diff --git a/modules/images.py b/modules/images.py
index 08a72e67..8146f580 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -501,30 +501,39 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image = params.image
fullfn = params.filename
info = params.pnginfo.get(pnginfo_section_name, None)
- fullfn_without_extension, extension = os.path.splitext(params.filename)
- def exif_bytes():
- return piexif.dump({
- "Exif": {
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
- },
- })
+ def _atomically_save_image(image_to_save, filename_without_extension, extension):
+ # save image with .tmp extension to avoid race condition when another process detects new image in the directory
+ temp_file_path = filename_without_extension + ".tmp"
+ image_format = Image.registered_extensions()[extension]
- if extension.lower() == '.png':
- pnginfo_data = PngImagePlugin.PngInfo()
- if opts.enable_pnginfo:
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
+ if extension.lower() == '.png':
+ pnginfo_data = PngImagePlugin.PngInfo()
+ if opts.enable_pnginfo:
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
- image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
- elif extension.lower() in (".jpg", ".jpeg", ".webp"):
- image.save(fullfn, quality=opts.jpeg_quality)
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
- if opts.enable_pnginfo and info is not None:
- piexif.insert(exif_bytes(), fullfn)
- else:
- image.save(fullfn, quality=opts.jpeg_quality)
+ if opts.enable_pnginfo and info is not None:
+ exif_bytes = piexif.dump({
+ "Exif": {
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
+ },
+ })
+
+ piexif.insert(exif_bytes, temp_file_path)
+ else:
+ image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+
+ # atomically rename the file with correct extension
+ os.replace(temp_file_path, filename_without_extension + extension)
+
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
+ _atomically_save_image(image, fullfn_without_extension, extension)
image.already_saved_as = fullfn
@@ -538,9 +547,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif oversize:
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
- image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
- if opts.enable_pnginfo and info is not None:
- piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt"
diff --git a/modules/processing.py b/modules/processing.py
index 3d2c4dc9..24c537d1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,15 +13,20 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
+import modules.sd_models as sd_models
+import modules.sd_vae as sd_vae
import logging
+from ldm.data.util import AddMiDaS
+from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+from einops import repeat, rearrange
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -150,11 +155,26 @@ class StableDiffusionProcessing():
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+ def depth2img_image_conditioning(self, source_image):
+ # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
+ transformer = AddMiDaS(model_type="dpt_hybrid")
+ transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+ midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning = torch.nn.functional.interpolate(
+ self.sd_model.depth_model(midas_in),
+ size=conditioning_image.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ (depth_min, depth_max) = torch.aminmax(conditioning)
+ conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+ return conditioning
+
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -191,6 +211,18 @@ class StableDiffusionProcessing():
return image_conditioning
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+ # identify itself with a field common to all models. The conditioning_key is also hybrid.
+ if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+ return self.depth2img_image_conditioning(source_image)
+
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+ # Dummy zero conditioning if we're not using inpainting or depth model.
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -424,8 +456,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible
- if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
+ if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p)
@@ -433,6 +467,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
for k, v in stored_opts.items():
setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+ if k == 'sd_vae': sd_vae.reload_vae_weights()
return res
@@ -541,9 +577,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- if opts.filter_nsfw:
- import modules.safety as safety
- x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
+ if p.scripts is not None:
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -734,7 +769,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -749,6 +784,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert
+ self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
self.mask = None
self.nmask = None
self.image_conditioning = None
@@ -862,6 +898,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ if self.initial_noise_multiplier != 1.0:
+ self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
+ x *= self.initial_noise_multiplier
+
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
diff --git a/modules/safety.py b/modules/safety.py
deleted file mode 100644
index cff4b278..00000000
--- a/modules/safety.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
- """
- Convert a numpy image or a batch of images to a PIL image.
- """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
- global safety_feature_extractor, safety_checker
-
- if safety_feature_extractor is None:
- safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
- return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
- x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
- x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- return x
diff --git a/modules/scripts.py b/modules/scripts.py
index b934d881..23ca195d 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -88,6 +88,17 @@ class Script:
pass
+ def postprocess_batch(self, p, *args, **kwargs):
+ """
+ Same as process_batch(), but called for every batch after it has been generated.
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -347,6 +358,15 @@ class ScriptRunner:
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_batch(self, p, images, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 95a17093..690a9ec2 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,17 +1,11 @@
-import math
-import os
-import sys
-import traceback
import torch
-import numpy as np
-from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
-from modules.shared import opts, device, cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip
+from modules.shared import cmd_opts
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -35,10 +29,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
+
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
+ ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
new file mode 100644
index 00000000..1b9d7757
--- /dev/null
+++ b/modules/sd_hijack_unet.py
@@ -0,0 +1,30 @@
+import torch
+
+
+class TorchHijackForUnet:
+ """
+ This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+ this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64
+ """
+
+ def __getattr__(self, item):
+ if item == 'cat':
+ return self.cat
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+ def cat(self, tensors, *args, **kwargs):
+ if len(tensors) == 2:
+ a, b = tensors
+ if a.shape[-2:] != b.shape[-2:]:
+ a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+ tensors = (a, b)
+
+ return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 283cf1cd..5b37f3fe 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,6 +7,9 @@ import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
+from os import mkdir
+from urllib import request
+import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
@@ -36,6 +39,7 @@ def setup_model():
os.makedirs(model_path)
list_models()
+ enable_midas_autodownload()
def checkpoint_tiles():
@@ -223,10 +227,54 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ sd_vae.delete_base_vae()
+ sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)
+def enable_midas_autodownload():
+ """
+ Gives the ldm.modules.midas.api.load_model function automatic downloading.
+
+ When the 512-depth-ema model, and other future models like it, is loaded,
+ it calls midas.api.load_model to load the associated midas depth model.
+ This function applies a wrapper to download the model to the correct
+ location automatically.
+ """
+
+ midas_path = os.path.join(models_path, 'midas')
+
+ # stable-diffusion-stability-ai hard-codes the midas model path to
+ # a location that differs from where other scripts using this model look.
+ # HACK: Overriding the path here.
+ for k, v in midas.api.ISL_PATHS.items():
+ file_name = os.path.basename(v)
+ midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
+
+ midas_urls = {
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
+ }
+
+ midas.api.load_model_inner = midas.api.load_model
+
+ def load_model_wrapper(model_type):
+ path = midas.api.ISL_PATHS[model_type]
+ if not os.path.exists(path):
+ if not os.path.exists(midas_path):
+ mkdir(midas_path)
+
+ print(f"Downloading midas model weights for {model_type} to {path}")
+ request.urlretrieve(midas_urls[model_type], path)
+ print(f"{model_type} downloaded")
+
+ return midas.api.load_model_inner(model_type)
+
+ midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 9c120975..25638a83 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -4,6 +4,7 @@ from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob
+from copy import deepcopy
model_dir = "Stable-diffusion"
@@ -15,7 +16,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-default_vae_dict = {"auto": "auto", "None": "None"}
+default_vae_dict = {"auto": "auto", "None": None, None: None}
default_vae_list = ["auto", "None"]
@@ -39,7 +40,8 @@ def get_base_vae(model):
def store_base_vae(model):
global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info:
- base_vae = model.first_stage_model.state_dict().copy()
+ assert not loaded_vae_file, "Trying to store non-base VAE!"
+ base_vae = deepcopy(model.first_stage_model.state_dict())
checkpoint_info = model.sd_checkpoint_info
@@ -50,9 +52,11 @@ def delete_base_vae():
def restore_base_vae(model):
- global base_vae, checkpoint_info
+ global loaded_vae_file
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
- load_vae_dict(model, base_vae)
+ print("Restoring base VAE")
+ _load_vae_dict(model, base_vae)
+ loaded_vae_file = None
delete_base_vae()
@@ -148,9 +152,10 @@ def load_vae(model, vae_file=None):
if vae_file:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
+ store_base_vae(model)
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- load_vae_dict(model, vae_dict_1)
+ _load_vae_dict(model, vae_dict_1)
# If vae used is not in dict, update it
# It will be removed on refresh though
@@ -158,30 +163,22 @@ def load_vae(model, vae_file=None):
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
+ elif loaded_vae_file:
+ restore_base_vae(model)
loaded_vae_file = vae_file
- """
- # Save current VAE to VAE settings, maybe? will it work?
- if save_settings:
- if vae_file is None:
- vae_opt = "None"
-
- # shared.opts.sd_vae = vae_opt
- """
-
first_load = False
# don't call this from outside
-def load_vae_dict(model, vae_dict_1=None):
- if vae_dict_1:
- store_base_vae(model)
- model.first_stage_model.load_state_dict(vae_dict_1)
- else:
- restore_base_vae()
+def _load_vae_dict(model, vae_dict_1):
+ model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
+def clear_loaded_vae():
+ global loaded_vae_file
+ loaded_vae_file = None
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
diff --git a/modules/shared.py b/modules/shared.py
index dc45fcaa..272267c1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -359,6 +359,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
@@ -366,7 +367,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
@@ -395,6 +395,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
+ "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
diff --git a/modules/ui.py b/modules/ui.py
index b2b8de90..28481e33 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -82,6 +82,7 @@ folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
+clear_prompt_symbol = '\U0001F5D1' # 🗑️
def plaintext_to_html(text):
@@ -302,8 +303,8 @@ def create_seed_inputs():
with gr.Row(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0)
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0)
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -316,6 +317,17 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+def connect_clear_prompt(button):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+ button.click(
+ _js="clear_prompt",
+ fn=None,
+ inputs=[],
+ outputs=[],
+ )
+
+
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -391,10 +403,17 @@ def create_toprow(is_img2img):
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
-
+ clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
button_interrogate = None
button_deepbooru = None
if is_img2img:
@@ -616,10 +635,14 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+
+
+
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
pass
@@ -635,8 +658,8 @@ def create_ui():
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -644,8 +667,8 @@ def create_ui():
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
- firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
+ firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0)
+ firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(equal_height=True):
@@ -770,7 +793,8 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -835,8 +859,8 @@ def create_ui():
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -1171,8 +1195,8 @@ def create_ui():
with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
- process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
@@ -1230,8 +1254,8 @@ def create_ui():
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
- training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1450,7 +1474,7 @@ def create_ui():
opts.save(shared.config_filename)
except RuntimeError:
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
- return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index b487ac25..1434f25f 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -206,12 +206,13 @@ def refresh_available_extensions_from_data(hide_tags):
if url is None:
continue
+ existing = installed_extension_urls.get(normalize_git_url(url), None)
+ extension_tags = extension_tags + ["installed"] if existing else extension_tags
+
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
hidden += 1
continue
- existing = installed_extension_urls.get(normalize_git_url(url), None)
-
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
@@ -222,7 +223,11 @@ def refresh_available_extensions_from_data(hide_tags):
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
- """
+
+ """
+
+ for tag in [x for x in extension_tags if x not in tags]:
+ tags[tag] = tag
code += """
</tbody>
@@ -272,7 +277,7 @@ def create_ui():
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
with gr.Row():
- hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
install_result = gr.HTML()
available_extensions_table = gr.HTML()