aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py70
-rw-r--r--modules/api/models.py1
-rw-r--r--modules/errors.py12
-rw-r--r--modules/extras.py228
-rw-r--r--modules/hypernetworks/hypernetwork.py2
-rw-r--r--modules/interrogate.py44
-rw-r--r--modules/paths.py14
-rw-r--r--modules/postprocessing.py103
-rw-r--r--modules/processing.py8
-rw-r--r--modules/scripts.py28
-rw-r--r--modules/scripts_postprocessing.py147
-rw-r--r--modules/sd_disable_initialization.py4
-rw-r--r--modules/sd_hijack_optimizations.py22
-rw-r--r--modules/shared.py25
-rw-r--r--modules/ui.py313
-rw-r--r--modules/ui_common.py202
-rw-r--r--modules/ui_components.py1
-rw-r--r--modules/ui_extra_networks.py29
-rw-r--r--modules/ui_extra_networks_hypernets.py3
-rw-r--r--modules/ui_extra_networks_textual_inversion.py3
-rw-r--r--modules/ui_postprocessing.py57
21 files changed, 751 insertions, 565 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index f2e9e884..25c65e57 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.extras import run_extras
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
@@ -23,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
+import piexif
+import piexif.helper
def upscaler_to_index(name: str):
try:
@@ -45,32 +46,46 @@ def validate_sampler_name(name):
def setUpscalers(req: dict):
reqDict = vars(req)
- reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
- reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
- reqDict.pop('upscaler_1')
- reqDict.pop('upscaler_2')
+ reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
+ reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
- return Image.open(BytesIO(base64.b64decode(encoding)))
+ try:
+ image = Image.open(BytesIO(base64.b64decode(encoding)))
+ return image
+ except Exception as err:
+ raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
- # Copy any text-only metadata
- use_metadata = False
- metadata = PngImagePlugin.PngInfo()
- for key, value in image.info.items():
- if isinstance(key, str) and isinstance(value, str):
- metadata.add_text(key, value)
- use_metadata = True
+ if opts.samples_format.lower() == 'png':
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+ image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
+
+ elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ parameters = image.info.get('parameters', None)
+ exif_bytes = piexif.dump({
+ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
+ })
+ if opts.samples_format.lower() in ("jpg", "jpeg"):
+ image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
+ else:
+ image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
+
+ else:
+ raise HTTPException(status_code=500, detail="Invalid image format")
- image.save(
- output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
- )
bytes_data = output_bytes.getvalue()
+
return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
@@ -244,7 +259,7 @@ class Api:
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
- result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
@@ -260,7 +275,7 @@ class Api:
reqDict.pop('imageList')
with self.queue_lock:
- result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
@@ -360,13 +375,16 @@ class Api:
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
- upscalers = []
-
- for upscaler in shared.sd_upscalers:
- u = upscaler.scaler
- upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
-
- return upscalers
+ return [
+ {
+ "name": upscaler.name,
+ "model_name": upscaler.scaler.model_name,
+ "model_path": upscaler.data_path,
+ "model_url": None,
+ "scale": upscaler.scale,
+ }
+ for upscaler in shared.sd_upscalers
+ ]
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
diff --git a/modules/api/models.py b/modules/api/models.py
index 1eb1fcf1..805bd8f7 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -220,6 +220,7 @@ class UpscalerItem(BaseModel):
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL")
+ scale: Optional[float] = Field(title="Scale")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
diff --git a/modules/errors.py b/modules/errors.py
index a10e8708..f6b80dbb 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable
""")
+already_displayed = {}
+
+
+def display_once(e: Exception, task):
+ if task in already_displayed:
+ return
+
+ display(e, task)
+
+ already_displayed[task] = 1
+
+
def run(code, task):
try:
code()
diff --git a/modules/extras.py b/modules/extras.py
index 1218f88f..36123aa5 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,230 +1,16 @@
-from __future__ import annotations
-import math
import os
-import sys
-import traceback
+import re
import shutil
-import numpy as np
-from PIL import Image
import torch
import tqdm
-from typing import Callable, List, OrderedDict, Tuple
-from functools import partial
-from dataclasses import dataclass
-
-from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
-from modules.shared import opts
-import modules.gfpgan_model
-from modules.ui import plaintext_to_html
-import modules.codeformer_model
+from modules import shared, images, sd_models, sd_vae
+from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
-class LruCache(OrderedDict):
- @dataclass(frozen=True)
- class Key:
- image_hash: int
- info_hash: int
- args_hash: int
-
- @dataclass
- class Value:
- image: Image.Image
- info: str
-
- def __init__(self, max_size: int = 5, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._max_size = max_size
-
- def get(self, key: LruCache.Key) -> LruCache.Value:
- ret = super().get(key)
- if ret is not None:
- self.move_to_end(key) # Move to end of eviction list
- return ret
-
- def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
- self[key] = value
- while len(self) > self._max_size:
- self.popitem(last=False)
-
-
-cached_images: LruCache = LruCache(max_size=5)
-
-
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
- devices.torch_gc()
-
- shared.state.begin()
- shared.state.job = 'extras'
-
- imageArr = []
- # Also keep track of original file names
- imageNameArr = []
- outputs = []
-
- if extras_mode == 1:
- #convert file to pillow image
- for img in image_folder:
- image = Image.open(img)
- imageArr.append(image)
- imageNameArr.append(os.path.splitext(img.orig_name)[0])
- elif extras_mode == 2:
- assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
-
- if input_dir == '':
- return outputs, "Please select an input directory.", ''
- image_list = shared.listfiles(input_dir)
- for img in image_list:
- try:
- image = Image.open(img)
- except Exception:
- continue
- imageArr.append(image)
- imageNameArr.append(img)
- else:
- imageArr.append(image)
- imageNameArr.append(None)
-
- if extras_mode == 2 and output_dir != '':
- outpath = output_dir
- else:
- outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- # Extra operation definitions
-
- def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- shared.state.job = 'extras-gfpgan'
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
-
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
-
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- return (res, info)
-
- def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- shared.state.job = 'extras-codeformer'
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
-
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
-
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
- return (res, info)
-
- def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- shared.state.job = 'extras-upscale'
- upscaler = shared.sd_upscalers[scaler_index]
- res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
- res = cropped
- return res
-
- def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
- nonlocal upscaling_resize
- if resize_mode == 1:
- upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
- crop_info = " (crop)" if upscaling_crop else ""
- info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
- return (image, info)
-
- @dataclass
- class UpscaleParams:
- upscaler_idx: int
- blend_alpha: float
-
- def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- blended_result: Image.Image = None
- image_hash: str = hash(np.array(image.getdata()).tobytes())
- for upscaler in params:
- upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
- upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = LruCache.Key(image_hash=image_hash,
- info_hash=hash(info),
- args_hash=hash(upscale_args))
- cached_entry = cached_images.get(cache_key)
- if cached_entry is None:
- res = upscale(image, *upscale_args)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
- cached_images.put(cache_key, LruCache.Value(image=res, info=info))
- else:
- res, info = cached_entry.image, cached_entry.info
-
- if blended_result is None:
- blended_result = res
- else:
- blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
- return (blended_result, info)
-
- # Build a list of operations to run
- facefix_ops: List[Callable] = []
- facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
- facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
-
- upscale_ops: List[Callable] = []
- upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
-
- if upscaling_resize != 0:
- step_params: List[UpscaleParams] = []
- step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
-
- upscale_ops.append(partial(run_upscalers_blend, step_params))
-
- extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
-
- for image, image_name in zip(imageArr, imageNameArr):
- if image is None:
- return outputs, "Please select an input image.", ''
-
- shared.state.textinfo = f'Processing image {image_name}'
-
- existing_pnginfo = image.info or {}
-
- image = image.convert("RGB")
- info = ""
- # Run each operation on each image
- for op in extras_ops:
- image, info = op(image, info)
-
- if opts.use_original_name_batch and image_name is not None:
- basename = os.path.splitext(os.path.basename(image_name))[0]
- else:
- basename = ''
-
- if opts.enable_pnginfo: # append info before save
- image.info = existing_pnginfo
- image.info["extras"] = info
-
- if save_output:
- # Add upscaler name as a suffix.
- suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
- # Add second upscaler if applicable.
- if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
- suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
-
- images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
-
- if extras_mode != 2 or show_extras_results :
- outputs.append(image)
-
- devices.torch_gc()
-
- return outputs, plaintext_to_html(info), ''
-
-def clear_cache():
- cached_images.clear()
-
def run_pnginfo(image):
if image is None:
@@ -285,7 +71,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -430,6 +216,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
for key in theta_0.keys():
theta_0[key] = to_half(theta_0[key], save_as_half)
+ if discard_weights:
+ regex = re.compile(discard_weights)
+ for key in list(theta_0):
+ if re.search(regex, key):
+ theta_0.pop(key, None)
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = filename_generator() if custom_name == '' else custom_name
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 80a47c79..503534e2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -715,6 +715,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
do_not_save_samples=True,
)
+ p.disable_extra_networks = True
+
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 19938cbb..236e6983 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
from collections import namedtuple
+from pathlib import Path
import re
import torch
@@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+def category_types():
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
+
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
+ category_types = ["artists", "flavors", "mediums", "movements"]
+
try:
os.makedirs(tmpdir)
-
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
-
+ for category_type in category_types:
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
except Exception as e:
@@ -51,32 +53,38 @@ class InterrogateModels:
def __init__(self, content_dir):
self.loaded_categories = None
+ self.skip_categories = []
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
def categories(self):
- if self.loaded_categories is not None:
- return self.loaded_categories
-
- self.loaded_categories = []
-
if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir)
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
+ return self.loaded_categories
+
+ self.loaded_categories = []
+
if os.path.exists(self.content_dir):
- for filename in os.listdir(self.content_dir):
- m = re_topn.search(filename)
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
+ category_types = []
+ for filename in Path(self.content_dir).glob('*.txt'):
+ category_types.append(filename.stem)
+ if filename.stem in self.skip_categories:
+ continue
+ m = re_topn.search(filename.stem)
topn = 1 if m is None else int(m.group(1))
-
- with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
+ with open(filename, "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
return self.loaded_categories
def load_blip_model(self):
- import models.blip
+ with paths.Prioritize("BLIP"):
+ import models.blip
files = modelloader.load_models(
model_path=os.path.join(paths.models_path, "BLIP"),
@@ -139,6 +147,8 @@ class InterrogateModels:
def rank(self, image_features, text_array, top_count=1):
import clip
+ devices.torch_gc()
+
if shared.opts.interrogate_clip_dict_limit != 0:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
diff --git a/modules/paths.py b/modules/paths.py
index 4dd03a35..20b3e4d8 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
+
+
+class Prioritize:
+ def __init__(self, name):
+ self.name = name
+ self.path = None
+
+ def __enter__(self):
+ self.path = sys.path.copy()
+ sys.path = [paths[self.name]] + sys.path
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.path = self.path
+ self.path = None
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
new file mode 100644
index 00000000..09d8e605
--- /dev/null
+++ b/modules/postprocessing.py
@@ -0,0 +1,103 @@
+import os
+
+from PIL import Image
+
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules.shared import opts
+
+
+def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+ devices.torch_gc()
+
+ shared.state.begin()
+ shared.state.job = 'extras'
+
+ image_data = []
+ image_names = []
+ outputs = []
+
+ if extras_mode == 1:
+ for img in image_folder:
+ image = Image.open(img)
+ image_data.append(image)
+ image_names.append(os.path.splitext(img.orig_name)[0])
+ elif extras_mode == 2:
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
+ assert input_dir, 'input directory not selected'
+
+ image_list = shared.listfiles(input_dir)
+ for filename in image_list:
+ try:
+ image = Image.open(filename)
+ except Exception:
+ continue
+ image_data.append(image)
+ image_names.append(filename)
+ else:
+ assert image, 'image not selected'
+
+ image_data.append(image)
+ image_names.append(None)
+
+ if extras_mode == 2 and output_dir != '':
+ outpath = output_dir
+ else:
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
+
+ infotext = ''
+
+ for image, name in zip(image_data, image_names):
+ shared.state.textinfo = name
+
+ existing_pnginfo = image.info or {}
+
+ pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
+
+ scripts.scripts_postproc.run(pp, args)
+
+ if opts.use_original_name_batch and name is not None:
+ basename = os.path.splitext(os.path.basename(name))[0]
+ else:
+ basename = ''
+
+ infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+
+ if opts.enable_pnginfo:
+ pp.image.info = existing_pnginfo
+ pp.image.info["postprocessing"] = infotext
+
+ if save_output:
+ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+
+ if extras_mode != 2 or show_extras_results:
+ outputs.append(pp.image)
+
+ devices.torch_gc()
+
+ return outputs, ui_common.plaintext_to_html(infotext), ''
+
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
+ """old handler for API"""
+
+ args = scripts.scripts_postproc.create_args_for_run({
+ "Upscale": {
+ "upscale_mode": resize_mode,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ },
+ "GFPGAN": {
+ "gfpgan_visibility": gfpgan_visibility,
+ },
+ "CodeFormer": {
+ "codeformer_visibility": codeformer_visibility,
+ "codeformer_weight": codeformer_weight,
+ },
+ })
+
+ return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
diff --git a/modules/processing.py b/modules/processing.py
index 01c1b53c..a1138491 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -140,6 +140,7 @@ class StableDiffusionProcessing:
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
+ self.disable_extra_networks = False
if not seed_enable_extras:
self.subseed = -1
@@ -579,7 +580,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
- extra_networks.activate(p, extra_network_data)
+ if not p.disable_extra_networks:
+ extra_networks.activate(p, extra_network_data)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
@@ -723,7 +725,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks:
+ extra_networks.deactivate(p, extra_network_data)
+
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
diff --git a/modules/scripts.py b/modules/scripts.py
index 4ffc369b..03907a63 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -7,7 +7,7 @@ from collections import namedtuple
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions, script_loading
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
@@ -150,8 +150,10 @@ def basedir():
return current_basedir
-scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+
+scripts_data = []
+postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
@@ -190,23 +192,31 @@ def list_files_with_name(filename):
def load_scripts():
global current_basedir
scripts_data.clear()
+ postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
+ def register_scripts_from_module(module):
+ for key, script_class in module.__dict__.items():
+ if type(script_class) != type:
+ continue
+
+ if issubclass(script_class, Script):
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+
for scriptfile in sorted(scripts_list):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
- module = script_loading.load_module(scriptfile.path)
-
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ script_module = script_loading.load_module(scriptfile.path)
+ register_scripts_from_module(script_module)
except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -413,6 +423,7 @@ class ScriptRunner:
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
scripts_current: ScriptRunner = None
@@ -423,12 +434,13 @@ def reload_script_body_only():
def reload_scripts():
- global scripts_txt2img, scripts_img2img
+ global scripts_txt2img, scripts_img2img, scripts_postproc
load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def IOComponent_init(self, *args, **kwargs):
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
new file mode 100644
index 00000000..25de02d0
--- /dev/null
+++ b/modules/scripts_postprocessing.py
@@ -0,0 +1,147 @@
+import os
+import gradio as gr
+
+from modules import errors, shared
+
+
+class PostprocessedImage:
+ def __init__(self, image):
+ self.image = image
+ self.info = {}
+
+
+class ScriptPostprocessing:
+ filename = None
+ controls = None
+ args_from = None
+ args_to = None
+
+ order = 1000
+ """scripts will be ordred by this value in postprocessing UI"""
+
+ name = None
+ """this function should return the title of the script."""
+
+ group = None
+ """A gr.Group component that has all script's UI inside it"""
+
+ def ui(self):
+ """
+ This function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be a dictionary that maps parameter names to components used in processing.
+ Values of those components will be passed to process() function.
+ """
+
+ pass
+
+ def process(self, pp: PostprocessedImage, **args):
+ """
+ This function is called to postprocess the image.
+ args contains a dictionary with all values returned by components from ui()
+ """
+
+ pass
+
+ def image_changed(self):
+ pass
+
+
+def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
+ try:
+ res = func(*args, **kwargs)
+ return res
+ except Exception as e:
+ errors.display(e, f"calling {filename}/{funcname}")
+
+ return default
+
+
+class ScriptPostprocessingRunner:
+ def __init__(self):
+ self.scripts = None
+ self.ui_created = False
+
+ def initialize_scripts(self, scripts_data):
+ self.scripts = []
+
+ for script_class, path, basedir, script_module in scripts_data:
+ script: ScriptPostprocessing = script_class()
+ script.filename = path
+
+ self.scripts.append(script)
+
+ def create_script_ui(self, script, inputs):
+ script.args_from = len(inputs)
+ script.args_to = len(inputs)
+
+ script.controls = wrap_call(script.ui, script.filename, "ui")
+
+ for control in script.controls.values():
+ control.custom_script_source = os.path.basename(script.filename)
+
+ inputs += list(script.controls.values())
+ script.args_to = len(inputs)
+
+ def scripts_in_preferred_order(self):
+ if self.scripts is None:
+ import modules.scripts
+ self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
+
+ scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+
+ def script_score(name):
+ name = name.lower()
+ for i, possible_match in enumerate(scripts_order):
+ if possible_match in name:
+ return i
+
+ return len(self.scripts)
+
+ script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
+
+ return sorted(self.scripts, key=lambda x: script_scores[x.name])
+
+ def setup_ui(self):
+ inputs = []
+
+ for script in self.scripts_in_preferred_order():
+ with gr.Box() as group:
+ self.create_script_ui(script, inputs)
+
+ script.group = group
+
+ self.ui_created = True
+ return inputs
+
+ def run(self, pp: PostprocessedImage, args):
+ for script in self.scripts_in_preferred_order():
+ shared.state.job = script.name
+
+ script_args = args[script.args_from:script.args_to]
+
+ process_args = {}
+ for (name, component), value in zip(script.controls.items(), script_args):
+ process_args[name] = value
+
+ script.process(pp, **process_args)
+
+ def create_args_for_run(self, scripts_args):
+ if not self.ui_created:
+ with gr.Blocks(analytics_enabled=False):
+ self.setup_ui()
+
+ scripts = self.scripts_in_preferred_order()
+ args = [None] * max([x.args_to for x in scripts])
+
+ for script in scripts:
+ script_args_dict = scripts_args.get(script.name, None)
+ if script_args_dict is not None:
+
+ for i, name in enumerate(script.controls):
+ args[script.args_from + i] = script_args_dict.get(name, None)
+
+ return args
+
+ def image_changed(self):
+ for script in self.scripts_in_preferred_order():
+ script.image_changed()
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index c72d8efc..e90aa9fe 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -41,7 +41,9 @@ class DisableInitialization:
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
- return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res.name_or_path = pretrained_model_name_or_path
+ return res
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 4fa54329..74452709 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared
+from modules import shared, errors
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
)
+def get_xformers_flash_attention_op(q, k, v):
+ if not shared.cmd_opts.xformers_flash_attention:
+ return None
+
+ try:
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+ fw, bw = flash_attention_op
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+ return flash_attention_op
+ except Exception as e:
+ errors.display_once(e, "enabling flash attention")
+
+ return None
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@@ -290,7 +305,8 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -365,7 +381,7 @@ def xformers_attnblock_forward(self, x):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
- out = xformers.ops.memory_efficient_attention(q, k, v)
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
diff --git a/modules/shared.py b/modules/shared.py
index 52bbb807..5f713bee 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
+parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
@@ -398,7 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
@@ -406,7 +407,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -422,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
@@ -429,6 +432,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
}))
+options_templates.update(options_section(('extra_networks', "Extra Networks"), {
+ "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }),
+}))
+
options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@@ -443,9 +450,12 @@ options_templates.update(options_section(('ui', "User interface"), {
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
- 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
- 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
- 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
+ "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
+ "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
options_templates.update(options_section(('ui', "Live previews"), {
@@ -470,6 +480,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
}))
+options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+ 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+ 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+}))
+
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
diff --git a/modules/ui.py b/modules/ui.py
index 64180889..e701393c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,7 +5,6 @@ import mimetypes
import os
import platform
import random
-import subprocess as sp
import sys
import tempfile
import time
@@ -20,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -41,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+import modules.extras
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
@@ -75,6 +75,7 @@ css_hide_progressbar = """
.wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
+.wrap.cover-bg .z-20::before { content:"" }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
@@ -85,7 +86,6 @@ css_hide_progressbar = """
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
paste_symbol = '\u2199\ufe0f' # ↙
-folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
@@ -94,78 +94,14 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text):
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
- return text
+ return ui_common.plaintext_to_html(text)
+
def send_gradio_gallery_to_image(x):
if len(x) == 0:
return None
return image_from_url_text(x[0])
-def save_files(js_data, images, do_make_zip, index):
- import csv
- filenames = []
- fullfns = []
-
- #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
- class MyObject:
- def __init__(self, d=None):
- if d is not None:
- for key, value in d.items():
- setattr(self, key, value)
-
- data = json.loads(js_data)
-
- p = MyObject(data)
- path = opts.outdir_save
- save_to_dirs = opts.use_save_to_dirs_for_ui
- extension: str = opts.samples_format
- start_index = 0
-
- if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
- images = [images[index]]
- start_index = index
-
- os.makedirs(opts.outdir_save, exist_ok=True)
-
- with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
- at_start = file.tell() == 0
- writer = csv.writer(file)
- if at_start:
- writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
-
- for image_index, filedata in enumerate(images, start_index):
- image = image_from_url_text(filedata)
-
- is_grid = image_index < p.index_of_first_image
- i = 0 if is_grid else (image_index - p.index_of_first_image)
-
- fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
-
- filename = os.path.relpath(fullfn, path)
- filenames.append(filename)
- fullfns.append(fullfn)
- if txt_fullfn:
- filenames.append(os.path.basename(txt_fullfn))
- fullfns.append(txt_fullfn)
-
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
-
- # Make Zip
- if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
-
- from zipfile import ZipFile
- with ZipFile(zip_filepath, "w") as zip_file:
- for i in range(len(fullfns)):
- with open(fullfns[i], mode="rb") as f:
- zip_file.writestr(filenames[i], f.read())
- fullfns.insert(0, zip_filepath)
-
- return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
def visit(x, func, path=""):
if hasattr(x, 'children'):
for c in x.children:
@@ -443,19 +379,6 @@ def apply_setting(key, value):
opts.save(shared.config_filename)
return getattr(opts, key)
-
-def update_generation_info(generation_info, html_info, img_index):
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info, gr.update()
- return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info, gr.update()
-
-
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@@ -476,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
- inputs=[generation_info, html_info, html_info],
- outputs=[html_info, html_info],
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ],
- show_progress=False,
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+ return ui_common.create_output_panel(tabname, outdir)
def create_sampler_and_steps_selection(choices, tabname):
@@ -935,7 +758,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
- with FormRow(elem_id="img2img_checkboxes"):
+ with FormRow(elem_id="img2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
@@ -1122,86 +945,7 @@ def create_ui():
modules.scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='compact'):
- with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', elem_id="extras_single_tab"):
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
-
- with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
-
- with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
- show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
-
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
- with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
- with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
- with gr.Group():
- with gr.Row():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
- with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
-
- with gr.Group():
- extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
-
- with gr.Group():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
-
- with gr.Group():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
-
- with gr.Group():
- upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
-
- result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
-
- submit.click(
- fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
- _js="get_extras_tab_index",
- inputs=[
- dummy_component,
- dummy_component,
- extras_image,
- image_batch,
- extras_batch_input_dir,
- extras_batch_output_dir,
- show_extras_results,
- gfpgan_visibility,
- codeformer_visibility,
- codeformer_weight,
- upscaling_resize,
- upscaling_resize_w,
- upscaling_resize_h,
- upscaling_crop,
- extras_upscaler_1,
- extras_upscaler_2,
- extras_upscaler_2_visibility,
- upscale_before_face_fix,
- ],
- outputs=[
- result_images,
- html_info_x,
- html_info,
- ]
- )
- parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
- extras_image.change(
- fn=modules.extras.clear_cache,
- inputs=[], outputs=[]
- )
+ ui_postprocessing.create_ui()
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
@@ -1222,10 +966,19 @@ def create_ui():
outputs=[html, generation_info, html2],
)
+ def update_interp_description(value):
+ interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+ interp_descriptions = {
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+ }
+ return interp_descriptions[value]
+
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact'):
- gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
+ interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
@@ -1240,6 +993,7 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
@@ -1254,6 +1008,9 @@ def create_ui():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+ with FormRow():
+ discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
@@ -1265,7 +1022,7 @@ def create_ui():
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
- with gr.Row().style(equal_height=False):
+ with gr.Row(variant="compact").style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
@@ -1844,6 +1601,7 @@ def create_ui():
checkpoint_format,
config_source,
bake_in_vae,
+ discard_weights,
],
outputs=[
primary_model_name,
@@ -1931,28 +1689,27 @@ def create_ui():
with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
return demo
def reload_javascript():
- with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
- javascript = f'<script>{jsfile.read()}</script>'
-
- scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
- for basedir, filename, path in scripts_list:
- with open(path, "r", encoding="utf8") as jsfile:
- javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
+ head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n'
+ inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
- javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
+ inline += f"set_theme('{cmd_opts.theme}');"
+
+ for script in modules.scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
- javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+ head += f'<script type="text/javascript">{inline}</script>\n'
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(
- b'</head>', f'{javascript}</head>'.encode("utf8"))
+ res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
res.init_headers()
return res
diff --git a/modules/ui_common.py b/modules/ui_common.py
new file mode 100644
index 00000000..9405ac1f
--- /dev/null
+++ b/modules/ui_common.py
@@ -0,0 +1,202 @@
+import json
+import html
+import os
+import platform
+import sys
+
+import gradio as gr
+import subprocess as sp
+
+from modules import call_queue, shared
+from modules.generation_parameters_copypaste import image_from_url_text
+import modules.images
+
+folder_symbol = '\U0001f4c2' # 📂
+
+
+def update_generation_info(generation_info, html_info, img_index):
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info, gr.update()
+
+
+def plaintext_to_html(text):
+ text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
+ return text
+
+
+def save_files(js_data, images, do_make_zip, index):
+ import csv
+ filenames = []
+ fullfns = []
+
+ #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+ class MyObject:
+ def __init__(self, d=None):
+ if d is not None:
+ for key, value in d.items():
+ setattr(self, key, value)
+
+ data = json.loads(js_data)
+
+ p = MyObject(data)
+ path = shared.opts.outdir_save
+ save_to_dirs = shared.opts.use_save_to_dirs_for_ui
+ extension: str = shared.opts.samples_format
+ start_index = 0
+
+ if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
+ images = [images[index]]
+ start_index = index
+
+ os.makedirs(shared.opts.outdir_save, exist_ok=True)
+
+ with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
+ at_start = file.tell() == 0
+ writer = csv.writer(file)
+ if at_start:
+ writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
+
+ for image_index, filedata in enumerate(images, start_index):
+ image = image_from_url_text(filedata)
+
+ is_grid = image_index < p.index_of_first_image
+ i = 0 if is_grid else (image_index - p.index_of_first_image)
+
+ fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+
+ filename = os.path.relpath(fullfn, path)
+ filenames.append(filename)
+ fullfns.append(fullfn)
+ if txt_fullfn:
+ filenames.append(os.path.basename(txt_fullfn))
+ fullfns.append(txt_fullfn)
+
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+
+ # Make Zip
+ if do_make_zip:
+ zip_filepath = os.path.join(path, "images.zip")
+
+ from zipfile import ZipFile
+ with ZipFile(zip_filepath, "w") as zip_file:
+ for i in range(len(fullfns)):
+ with open(fullfns[i], mode="rb") as f:
+ zip_file.writestr(filenames[i], f.read())
+ fullfns.insert(0, zip_filepath)
+
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
+
+
+def create_output_panel(tabname, outdir):
+ from modules import shared
+ import modules.generation_parameters_copypaste as parameters_copypaste
+
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+ open_folder_button.click(
+ fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
+ )
+
+ save.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ save_zip.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 46324425..9aec3097 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -47,3 +47,4 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
def get_block_name(self):
return "colorpicker"
+
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index e2e060c8..8b4f97f8 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -3,6 +3,7 @@ import os.path
from modules import shared
import gradio as gr
import json
+import html
from modules.generation_parameters_copypaste import image_from_url_text
@@ -26,6 +27,7 @@ class ExtraNetworksPage:
pass
def create_html(self, tabname):
+ view = shared.opts.extra_networks_default_view
items_html = ''
for item in self.list_items():
@@ -36,7 +38,7 @@ class ExtraNetworksPage:
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
res = f"""
-<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
+<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
{items_html}
</div>
"""
@@ -53,12 +55,13 @@ class ExtraNetworksPage:
preview = item.get("preview", None)
args = {
- "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
- "prompt": json.dumps(item["prompt"]),
+ "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
+ "prompt": item["prompt"],
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
- "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+ "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
}
return self.card_page.format(**args)
@@ -79,10 +82,26 @@ class ExtraNetworksUi:
self.tabname = None
+def pages_in_preferred_order(pages):
+ tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
+
+ def tab_name_score(name):
+ name = name.lower()
+ for i, possible_match in enumerate(tab_order):
+ if possible_match in name:
+ return i
+
+ return len(pages)
+
+ tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
+
+ return sorted(pages, key=lambda x: tab_scores[x.name])
+
+
def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
- ui.stored_extra_pages = extra_pages.copy()
+ ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 312dbaf0..65d000cf 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -1,3 +1,4 @@
+import json
import os
from modules import shared, ui_extra_networks
@@ -25,7 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
"name": name,
"filename": path,
"preview": preview,
- "prompt": f"<hypernet:{name}:1.0>",
+ "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index e4a6e3bf..dbd23d2d 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,3 +1,4 @@
+import json
import os
from modules import ui_extra_networks, sd_hijack
@@ -24,7 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
"name": embedding.name,
"filename": embedding.filename,
"preview": preview,
- "prompt": embedding.name,
+ "prompt": json.dumps(embedding.name),
"local_preview": path + ".preview.png",
}
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
new file mode 100644
index 00000000..b418d955
--- /dev/null
+++ b/modules/ui_postprocessing.py
@@ -0,0 +1,57 @@
+import gradio as gr
+from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+
+
+def create_ui():
+ tab_index = gr.State(value=0)
+
+ with gr.Row().style(equal_height=False, variant='compact'):
+ with gr.Column(variant='compact'):
+ with gr.Tabs(elem_id="mode_extras"):
+ with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
+ script_inputs = scripts.scripts_postproc.setup_ui()
+
+ with gr.Column():
+ result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+
+ submit.click(
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+ inputs=[
+ tab_index,
+ extras_image,
+ image_batch,
+ extras_batch_input_dir,
+ extras_batch_output_dir,
+ show_extras_results,
+ *script_inputs
+ ],
+ outputs=[
+ result_images,
+ html_info_x,
+ html_info,
+ ]
+ )
+
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+ extras_image.change(
+ fn=scripts.scripts_postproc.image_changed,
+ inputs=[], outputs=[]
+ )