aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/preprocess.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/preprocess.py')
-rw-r--r--modules/textual_inversion/preprocess.py239
1 files changed, 165 insertions, 74 deletions
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..56b9b2eb 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,104 +1,195 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
+import time
-from modules import shared, images
+from modules import shared, images, deepbooru
+from modules.paths import models_path
+from modules.shared import opts, cmd_opts
+from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
- src = os.path.abspath(process_src)
- dst = os.path.abspath(process_dst)
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+ try:
+ if process_caption:
+ shared.interrogator.load()
- assert src != dst, 'same directory specified as source and destination'
+ if process_caption_deepbooru:
+ deepbooru.model.start()
- os.makedirs(dst, exist_ok=True)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
- files = os.listdir(src)
+ finally:
- shared.state.textinfo = "Preprocessing..."
- shared.state.job_count = len(files)
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
- if process_caption:
- shared.interrogator.load()
+ if process_caption_deepbooru:
+ deepbooru.model.stop()
- def save_pic_with_caption(image, index):
- if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
- subindex[0] += 1
+def listfiles(dirname):
+ return os.listdir(dirname)
- def save_pic(image, index):
- save_pic_with_caption(image, index)
- if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+class PreprocessParams:
+ src = None
+ dstdir = None
+ subindex = 0
+ flip = False
+ process_caption = False
+ process_caption_deepbooru = False
+ preprocess_txt_action = None
- for index, imagefile in enumerate(tqdm.tqdm(files)):
- subindex = [0]
- filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
- if shared.state.interrupted:
- break
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
+ caption = ""
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
+ if params.process_caption:
+ caption += shared.interrogator.generate_caption(image)
- if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ if params.process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.model.tag_multi(image)
- top = img.crop((0, 0, size, size))
- save_pic(top, index)
+ filename_part = params.src
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
- bot = img.crop((0, img.height - size, size, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ basename = f"{index:05}-{params.subindex}-{filename_part}"
+ image.save(os.path.join(params.dstdir, f"{basename}.png"))
- left = img.crop((0, 0, size, size))
- save_pic(left, index)
+ if params.preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif params.preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif params.preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
- right = img.crop((img.width - size, 0, img.width, size))
- save_pic(right, index)
- else:
- img = images.resize_image(1, img, size, size)
- save_pic(img, index)
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
+
+ params.subindex += 1
- shared.state.nextjob()
- if process_caption:
- shared.interrogator.send_blip_to_ram()
+def save_pic(image, index, params, existing_caption=None):
+ save_pic_with_caption(image, index, params, existing_caption=existing_caption)
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
+ if params.flip:
+ save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+ width = process_width
+ height = process_height
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
+
+ assert src != dst, 'same directory specified as source and destination'
+
+ os.makedirs(dst, exist_ok=True)
+
+ files = listfiles(src)
+
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+
+ params = PreprocessParams()
+ params.dstdir = dst
+ params.flip = process_flip
+ params.process_caption = process_caption
+ params.process_caption_deepbooru = process_caption_deepbooru
+ params.preprocess_txt_action = preprocess_txt_action
+
+ for index, imagefile in enumerate(tqdm.tqdm(files)):
+ params.subindex = 0
+ filename = os.path.join(src, imagefile)
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
+
+ params.src = filename
+
+ existing_caption = None
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
+
+ if shared.state.interrupted:
break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()
+
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
+
+ process_default_resize = True
+
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+ save_pic(splitted, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_focal_crop and img.height != img.width:
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
+ )
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_default_resize:
+ img = images.resize_image(1, img, width, height)
+ save_pic(img, index, params, existing_caption=existing_caption)
+
+ shared.state.nextjob()