aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/preprocess.py38
-rw-r--r--modules/textual_inversion/textual_inversion.py8
2 files changed, 41 insertions, 5 deletions
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 64abff4d..c0ac11d3 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -20,7 +20,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height,
if process_caption_deepbooru:
deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
finally:
@@ -109,8 +109,30 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio):
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
+ wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
+ default=None
+ )
+ return wh and center_crop(image, *wh)
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -194,6 +216,14 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
+ if process_multicrop:
+ cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ if cropped is not None:
+ save_pic(cropped, index, params, existing_caption=existing_caption)
+ else:
+ print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
+ process_default_resize = False
+
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, params, existing_caption=existing_caption)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7e4a6d24..4e90f690 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -15,7 +15,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -50,6 +50,7 @@ class Embedding:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
+ self.filename = None
def save(self, filename):
embedding_data = {
@@ -182,6 +183,7 @@ class EmbeddingDatabase:
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
+ embedding.filename = path
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -452,6 +454,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -617,9 +621,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None