aboutsummaryrefslogtreecommitdiff
path: root/modules/images.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/images.py')
-rw-r--r--modules/images.py96
1 files changed, 68 insertions, 28 deletions
diff --git a/modules/images.py b/modules/images.py
index 40efc96c..b5412548 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,6 +1,6 @@
+from __future__ import annotations
+
import datetime
-import sys
-import traceback
import pytz
import io
@@ -12,7 +12,7 @@ import re
import numpy as np
import piexif
import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string
import json
import hashlib
@@ -21,6 +21,8 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
+import modules.sd_vae as sd_vae
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -139,6 +141,11 @@ class GridAnnotation:
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
fnt = get_font(fontsize)
- color_active = (0, 0, 0)
- color_inactive = (153, 153, 153)
-
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
cols = im.width // width
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
- calc_img = Image.new("RGB", (1, 1), "white")
+ calc_img = Image.new("RGB", (1, 1), color_background)
calc_d = ImageDraw.Draw(calc_img)
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
for row in range(rows):
for col in range(cols):
@@ -336,8 +340,20 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
+ def get_vae_filename(self): #get the name of the VAE file.
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
+ else:
+ return split_file_name[0]
+
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.image.width,
@@ -354,19 +370,23 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
- 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+ 'batch_size': lambda self: self.p.batch_size,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'user': lambda self: self.p.user,
+ 'vae_filename': lambda self: self.get_vae_filename(),
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt, image):
+ def __init__(self, p, seed, prompt, image, zip=False):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
+ self.zip = zip
def hasprompt(self, *args):
lower = self.prompt.lower()
@@ -390,7 +410,7 @@ class FilenameGenerator:
prompt_no_style = self.prompt
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
- if len(style) > 0:
+ if style:
for part in style.split("{prompt}"):
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
@@ -399,7 +419,7 @@ class FilenameGenerator:
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
def prompt_words(self):
- words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+ words = [x for x in re_nonletters.split(self.prompt or "") if x]
if len(words) == 0:
words = ["empty"]
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
@@ -407,7 +427,7 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+ time_format = args[0] if (args and args[0] != "") else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError:
@@ -446,8 +466,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -482,13 +501,23 @@ def get_next_sequence_number(path, basename):
return result + 1
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+ """
+ Saves image to filename, including geninfo as text information for generation info.
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+ """
+
if extension is None:
extension = os.path.splitext(filename)[1]
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo[pnginfo_section_name] = geninfo
+
if opts.enable_pnginfo:
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in (existing_pnginfo or {}).items():
@@ -607,7 +636,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
temp_file_path = f"{filename_without_extension}.tmp"
- save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension)
@@ -624,12 +653,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
-
+ resize_to = None
if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
+ if resize_to is not None:
+ try:
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+ image = image.resize(resize_to, LANCZOS)
+ except Exception:
+ image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
@@ -647,8 +682,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn
-def read_info_from_image(image):
- items = image.info or {}
+IGNORED_INFO_KEYS = {
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+ items = (image.info or {}).copy()
geninfo = items.pop('parameters', None)
@@ -664,9 +706,8 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
+ for field in IGNORED_INFO_KEYS:
+ items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
@@ -677,8 +718,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items