aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/autocrop.py341
-rw-r--r--modules/textual_inversion/dataset.py169
-rw-r--r--modules/textual_inversion/image_embedding.py220
-rw-r--r--modules/textual_inversion/learn_schedule.py76
-rw-r--r--modules/textual_inversion/preprocess.py239
-rw-r--r--modules/textual_inversion/test_embedding.pngbin0 -> 489220 bytes
-rw-r--r--modules/textual_inversion/textual_inversion.py368
-rw-r--r--modules/textual_inversion/ui.py15
8 files changed, 1227 insertions, 201 deletions
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
new file mode 100644
index 00000000..9859974a
--- /dev/null
+++ b/modules/textual_inversion/autocrop.py
@@ -0,0 +1,341 @@
+import cv2
+import requests
+import os
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
+
+ focus = focal_point(im_debug, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
+
+ pois = []
+
+ weight_pref_total = 0
+ if len(corner_points) > 0:
+ weight_pref_total += settings.corner_points_weight
+ if len(entropy_points) > 0:
+ weight_pref_total += settings.entropy_points_weight
+ if len(face_points) > 0:
+ weight_pref_total += settings.face_points_weight
+
+ corner_centroid = None
+ if len(corner_points) > 0:
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
+
+ entropy_centroid = None
+ if len(entropy_points) > 0:
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
+
+ face_centroid = None
+ if len(face_points) > 0:
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
+
+ average_point = poi_average(pois, settings)
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ if settings.dnn_model_path is not None:
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.9, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
+ size = w,
+ weight = 1/len(faces[1])
+ )
+ )
+ return results
+ else:
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
+ return []
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+def centroid(pois):
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
+
+def poi_average(pois, settings):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
+ avg_x = round(x / weight)
+ avg_y = round(y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+ return w > h
+
+
+def is_portrait(w, h):
+ return h > w
+
+
+def is_square(w, h):
+ return w == h
+
+
+def download_and_cache_models(dirname):
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
+
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
+
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path \ No newline at end of file
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..f470324a 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,23 +3,36 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
+ self.filename = filename
+ self.filename_text = filename_text
+ self.latent_dist = latent_dist
+ self.latent_sample = latent_sample
+ self.cond = cond
+ self.cond_text = cond_text
+ self.pixel_values = pixel_values
-class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+class PersonalizedBase(Dataset):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
+
self.placeholder_token = placeholder_token
- self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -32,50 +45,128 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+
+
+ self.shuffle_tags = shuffle_tags
+ self.tag_drop_out = tag_drop_out
+
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
-
+ if shared.state.interrupted:
+ raise Exception("inturrupted")
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
+
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
- torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
- torchdata = torch.moveaxis(torchdata, 2, 0)
-
- init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
- init_latent = init_latent.to(devices.cpu)
-
- self.dataset.append((init_latent, filename_tokens))
-
- self.length = len(self.dataset) * repeats
-
- self.initial_indexes = np.arange(self.length) % len(self.dataset)
- self.indexes = None
- self.shuffle()
-
- def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
+ latent_sample = None
+
+ with torch.autocast("cuda"):
+ latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
+
+ if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+ latent_sampling_method = "once"
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ elif latent_sampling_method == "deterministic":
+ # Works only for DiagonalGaussianDistribution
+ latent_dist.std = 0
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ elif latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+
+ if not (self.tag_drop_out != 0 or self.shuffle_tags):
+ entry.cond_text = self.create_text(filename_text)
+
+ if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
+ with torch.autocast("cuda"):
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+
+ self.dataset.append(entry)
+ del torchdata
+ del latent_dist
+ del latent_sample
+
+ self.length = len(self.dataset)
+ assert self.length > 0, "No images have been found in the dataset."
+ self.batch_size = min(batch_size, self.length)
+ self.gradient_step = min(gradient_step, self.length // self.batch_size)
+ self.latent_sampling_method = latent_sampling_method
+
+ def create_text(self, filename_text):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ tags = filename_text.split(',')
+ if self.tag_drop_out != 0:
+ tags = [t for t in tags if random.random() > self.tag_drop_out]
+ if self.shuffle_tags:
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
+ return text
def __len__(self):
return self.length
def __getitem__(self, i):
- if i % len(self.dataset) == 0:
- self.shuffle()
-
- index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
-
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
-
- return x, text
+ entry = self.dataset[i]
+ if self.tag_drop_out != 0 or self.shuffle_tags:
+ entry.cond_text = self.create_text(entry.filename_text)
+ if self.latent_sampling_method == "random":
+ entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
+ return entry
+
+class PersonalizedDataLoader(DataLoader):
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+ super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ if latent_sampling_method == "random":
+ self.collate_fn = collate_wrapper_random
+ else:
+ self.collate_fn = collate_wrapper
+
+
+class BatchLoader:
+ def __init__(self, data):
+ self.cond_text = [entry.cond_text for entry in data]
+ self.cond = [entry.cond for entry in data]
+ self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ #self.emb_index = [entry.emb_index for entry in data]
+ #print(self.latent_sample.device)
+
+ def pin_memory(self):
+ self.latent_sample = self.latent_sample.pin_memory()
+ return self
+
+def collate_wrapper(batch):
+ return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+ def __init__(self, data):
+ super().__init__(data)
+
+ def pin_memory(self):
+ return self
+
+def collate_wrapper_random(batch):
+ return BatchLoaderRandom(batch) \ No newline at end of file
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
new file mode 100644
index 00000000..ea653806
--- /dev/null
+++ b/modules/textual_inversion/image_embedding.py
@@ -0,0 +1,220 @@
+import base64
+import json
+import numpy as np
+import zlib
+from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+from fonts.ttf import Roboto
+import torch
+from modules.shared import opts
+
+
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, obj)
+
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, d):
+ if 'TORCHTENSOR' in d:
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
+ return d
+
+
+def embedding_to_b64(data):
+ d = json.dumps(data, cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+
+def embedding_from_b64(data):
+ d = base64.b64decode(data)
+ return json.loads(d, cls=EmbeddingDecoder)
+
+
+def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
+ while True:
+ seed = (a * seed + c) % m
+ yield seed % 255
+
+
+def xor_block(block):
+ g = lcg()
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
+
+
+def style_block(block, sequence):
+ im = Image.new('RGB', (block.shape[1], block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i = 0
+ for x in range(-6, im.size[0], 8):
+ for yi, y in enumerate(range(-6, im.size[1], 8)):
+ offset = 0
+ if yi % 2 == 0:
+ offset = 4
+ shade = sequence[i % len(sequence)]
+ i += 1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+
+ return block ^ fg
+
+
+def insert_image_data_embed(image, data):
+ d = 3
+ data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
+ data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
+ data_np_high = data_np_ >> 4
+ data_np_low = data_np_ & 0x0F
+
+ h = image.size[1]
+ next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
+ next_size = next_size + ((h*d)-(next_size % (h*d)))
+
+ data_np_low.resize(next_size)
+ data_np_low = data_np_low.reshape((h, -1, d))
+
+ data_np_high.resize(next_size)
+ data_np_high = data_np_high.reshape((h, -1, d))
+
+ edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
+
+ data_np_low = style_block(data_np_low, sequence=edge_style)
+ data_np_low = xor_block(data_np_low)
+ data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
+ data_np_high = xor_block(data_np_high)
+
+ im_low = Image.fromarray(data_np_low, mode='RGB')
+ im_high = Image.fromarray(data_np_high, mode='RGB')
+
+ background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
+ background.paste(im_low, (0, 0))
+ background.paste(image, (im_low.size[0]+1, 0))
+ background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
+
+ return background
+
+
+def crop_black(img, tol=0):
+ mask = (img > tol).all(2)
+ mask0, mask1 = mask.any(0), mask.any(1)
+ col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
+ row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
+ return img[row_start:row_end, col_start:col_end]
+
+
+def extract_image_data_embed(image):
+ d = 3
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
+ if black_cols[0].shape[0] < 2:
+ print('No Image data blocks found.')
+ return None
+
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
+ data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
+
+ data_block_lower = xor_block(data_block_lower)
+ data_block_upper = xor_block(data_block_upper)
+
+ data_block = (data_block_upper << 4) | (data_block_lower)
+ data_block = data_block.flatten().tobytes()
+
+ data = zlib.decompress(data_block)
+ return json.loads(data, cls=EmbeddingDecoder)
+
+
+def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from math import cos
+
+ image = srcimage.copy()
+ fontsize = 32
+ if textfont is None:
+ try:
+ textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
+ textfont = opts.font or Roboto
+ except Exception:
+ textfont = Roboto
+
+ factor = 1.5
+ gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
+ for y in range(image.size[1]):
+ mag = 1-cos(y/image.size[1]*factor)
+ mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
+ gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
+ image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
+
+ draw = ImageDraw.Draw(image)
+
+ font = ImageFont.truetype(textfont, fontsize)
+ padding = 10
+
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
+ font = ImageFont.truetype(textfont, fontsize)
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
+
+ _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
+ fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
+ fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
+ fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+
+ font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+
+ draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
+
+ return image
+
+
+if __name__ == '__main__':
+
+ testEmbed = Image.open('test_embedding.png')
+ data = extract_image_data_embed(testEmbed)
+ assert data is not None
+
+ data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
+ assert data is not None
+
+ image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
+ cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
+
+ test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
+
+ embedded_image = insert_image_data_embed(cap_image, test_embed)
+
+ retrived_embed = extract_image_data_embed(embedded_image)
+
+ assert str(retrived_embed) == str(test_embed)
+
+ embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
+
+ assert embedded_image == embedded_image2
+
+ g = lcg()
+ shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
+
+ reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
+ 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
+ 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
+ 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
+ 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
+ 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
+ 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
+ 204, 86, 73, 222, 44, 198, 118, 240, 97]
+
+ assert shared_random == reference_random
+
+ hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
+
+ assert 12731374 == hunna_kay_random_sum
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..dd0c0ad1
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,76 @@
+import tqdm
+
+
+class LearnScheduleIterator:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
+ """
+
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number < self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..56b9b2eb 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,104 +1,195 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
+import time
-from modules import shared, images
+from modules import shared, images, deepbooru
+from modules.paths import models_path
+from modules.shared import opts, cmd_opts
+from modules.textual_inversion import autocrop
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
- src = os.path.abspath(process_src)
- dst = os.path.abspath(process_dst)
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+ try:
+ if process_caption:
+ shared.interrogator.load()
- assert src != dst, 'same directory specified as source and destination'
+ if process_caption_deepbooru:
+ deepbooru.model.start()
- os.makedirs(dst, exist_ok=True)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
- files = os.listdir(src)
+ finally:
- shared.state.textinfo = "Preprocessing..."
- shared.state.job_count = len(files)
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
- if process_caption:
- shared.interrogator.load()
+ if process_caption_deepbooru:
+ deepbooru.model.stop()
- def save_pic_with_caption(image, index):
- if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
- subindex[0] += 1
+def listfiles(dirname):
+ return os.listdir(dirname)
- def save_pic(image, index):
- save_pic_with_caption(image, index)
- if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+class PreprocessParams:
+ src = None
+ dstdir = None
+ subindex = 0
+ flip = False
+ process_caption = False
+ process_caption_deepbooru = False
+ preprocess_txt_action = None
- for index, imagefile in enumerate(tqdm.tqdm(files)):
- subindex = [0]
- filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
- if shared.state.interrupted:
- break
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
+ caption = ""
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
+ if params.process_caption:
+ caption += shared.interrogator.generate_caption(image)
- if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ if params.process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.model.tag_multi(image)
- top = img.crop((0, 0, size, size))
- save_pic(top, index)
+ filename_part = params.src
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
- bot = img.crop((0, img.height - size, size, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ basename = f"{index:05}-{params.subindex}-{filename_part}"
+ image.save(os.path.join(params.dstdir, f"{basename}.png"))
- left = img.crop((0, 0, size, size))
- save_pic(left, index)
+ if params.preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif params.preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif params.preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
- right = img.crop((img.width - size, 0, img.width, size))
- save_pic(right, index)
- else:
- img = images.resize_image(1, img, size, size)
- save_pic(img, index)
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
+
+ params.subindex += 1
- shared.state.nextjob()
- if process_caption:
- shared.interrogator.send_blip_to_ram()
+def save_pic(image, index, params, existing_caption=None):
+ save_pic_with_caption(image, index, params, existing_caption=existing_caption)
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
+ if params.flip:
+ save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+ width = process_width
+ height = process_height
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
+
+ assert src != dst, 'same directory specified as source and destination'
+
+ os.makedirs(dst, exist_ok=True)
+
+ files = listfiles(src)
+
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+
+ params = PreprocessParams()
+ params.dstdir = dst
+ params.flip = process_flip
+ params.process_caption = process_caption
+ params.process_caption_deepbooru = process_caption_deepbooru
+ params.preprocess_txt_action = preprocess_txt_action
+
+ for index, imagefile in enumerate(tqdm.tqdm(files)):
+ params.subindex = 0
+ filename = os.path.join(src, imagefile)
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
+
+ params.src = filename
+
+ existing_caption = None
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
+
+ if shared.state.interrupted:
break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()
+
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
+
+ process_default_resize = True
+
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+ save_pic(splitted, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_focal_crop and img.height != img.width:
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
+ )
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_default_resize:
+ img = images.resize_image(1, img, width, height)
+ save_pic(img, index, params, existing_caption=existing_caption)
+
+ shared.state.nextjob()
diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png
new file mode 100644
index 00000000..07e2d9af
--- /dev/null
+++ b/modules/textual_inversion/test_embedding.png
Binary files differ
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..4eb75cb5 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,11 +6,17 @@ import torch
import tqdm
import html
import datetime
+import csv
+from PIL import Image, PngImagePlugin
-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+ insert_image_data_embed, extract_image_data_embed,
+ caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -58,7 +64,8 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
- ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+ # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
+ ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
@@ -80,7 +87,18 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
+ else:
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -102,7 +120,7 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
@@ -120,6 +138,7 @@ class EmbeddingDatabase:
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ print("Embeddings:", ', '.join(self.word_embeddings.keys()))
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -135,19 +154,23 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
- embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
- ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+
+ embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -156,8 +179,56 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
- assert embedding_name, 'embedding not selected'
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = (step - 1) // epoch_len
+ epoch_step = (step - 1) % epoch_len
+
+ csv_writer.writerow({
+ "step": step,
+ "epoch": epoch,
+ "epoch_step": epoch_step,
+ **values,
+ })
+
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
+ assert gradient_step > 0, "Gradient accumulation step must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_file, "Prompt template file is empty"
+ assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0 , "Max steps must be positive"
+ assert isinstance(save_model_every, int), "Save {name} must be integer"
+ assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0 , "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -165,6 +236,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+ unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
@@ -178,94 +250,224 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else:
images_dir = None
- cond_model = shared.sd_model.cond_stage_model
-
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
- embedding.vec.requires_grad = True
-
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
- losses = torch.zeros((32,))
-
- last_saved_file = "<none>"
- last_saved_image = "<none>"
+ checkpoint = sd_models.select_checkpoint()
- ititial_step = embedding.step or 0
- if ititial_step > steps:
+ initial_step = embedding.step or 0
+ if initial_step >= steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text) in pbar:
- embedding.step = i + ititial_step
-
- if embedding.step > steps:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([text])
-
- x = x.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- pbar.set_description(f"loss: {losses.mean():.7f}")
-
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
- embedding.save(last_saved_file)
-
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- prompt=text,
- steps=20,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
-
- processed = processing.process_images(p)
- image = processed.images[0]
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+
+ pin_memory = shared.opts.pin_memory
+
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
- shared.state.current_image = image
- image.save(last_saved_image)
+ latent_sampling_method = ds.latent_sampling_method
- last_saved_image += f", prompt: {text}"
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
- shared.state.job_no = embedding.step
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.textinfo = f"""
+ embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+ # n steps = batch_size * gradient_step * n image processed
+ steps_per_epoch = len(ds) // batch_size // gradient_step
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+ _loss_step = 0 #internal
+
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+ forced_filename = "<none>"
+ embedding_yet_to_be_embedded = False
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ for i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ # c = stack_conds(batch.cond).to(devices.device)
+ # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
+ # print(mask)
+ # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ c = shared.sd_model.cond_stage_model(batch.cond_text)
+ loss = shared.sd_model(x, c)[0] / gradient_step
+ del x
+
+ _loss_step += loss.item()
+ scaler.scale(loss).backward()
+
+ # go back until we reach gradient accumulation steps
+ if (j + 1) % gradient_step != 0:
+ continue
+ scaler.step(optimizer)
+ scaler.update()
+ embedding.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // steps_per_epoch
+ epoch_step = embedding.step % steps_per_epoch
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ #if shared.opts.save_optimizer_state:
+ #embedding.optimizer_state_dict = optimizer.state_dict()
+ save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
<p>
-Loss: {losses.mean():.7f}<br/>
-Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
-
- checkpoint = sd_models.select_checkpoint()
-
- embedding.sd_checkpoint = checkpoint.hash
- embedding.sd_checkpoint_name = checkpoint.model_name
- embedding.cached_checksum = None
- embedding.save(filename)
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr)
+ pass
+ finally:
+ pbar.leave = False
+ pbar.close()
+ shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename
+def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..35c4feef 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -18,13 +18,17 @@ def create_embedding(name, initialization_text, nvpt):
def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args)
- return "Preprocessing finished.", ""
+ return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
def train_embedding(*args):
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
+ apply_optimizations = shared.opts.training_xattention_optimizations
try:
- sd_hijack.undo_optimizations()
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -36,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception:
raise
finally:
- sd_hijack.apply_optimizations()
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()