aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/dataset.py82
-rw-r--r--modules/textual_inversion/image_embedding.py220
-rw-r--r--modules/textual_inversion/learn_schedule.py69
-rw-r--r--modules/textual_inversion/preprocess.py171
-rw-r--r--modules/textual_inversion/test_embedding.pngbin0 -> 489220 bytes
-rw-r--r--modules/textual_inversion/textual_inversion.py143
-rw-r--r--modules/textual_inversion/ui.py6
7 files changed, 585 insertions, 106 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..5b1c5002 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,18 +8,28 @@ from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.size = size
+ self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -33,16 +43,28 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
+ cond_model = shared.sd_model.cond_stage_model
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -53,29 +75,47 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens))
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
- self.length = len(self.dataset) * repeats
+ if include_cond:
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
- self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.dataset.append(entry)
+
+ assert len(self.dataset) > 0, "No images have been found in the dataset."
+ self.length = len(self.dataset) * repeats // batch_size
+
+ self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
+
+ def create_text(self, filename_text):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", filename_text)
+ return text
def __len__(self):
return self.length
def __getitem__(self, i):
- if i % len(self.dataset) == 0:
- self.shuffle()
+ res = []
- index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
+ for j in range(self.batch_size):
+ position = i * self.batch_size + j
+ if position % len(self.indexes) == 0:
+ self.shuffle()
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ index = self.indexes[position % len(self.indexes)]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
+
+ res.append(entry)
- return x, text
+ return res
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
new file mode 100644
index 00000000..ea653806
--- /dev/null
+++ b/modules/textual_inversion/image_embedding.py
@@ -0,0 +1,220 @@
+import base64
+import json
+import numpy as np
+import zlib
+from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+from fonts.ttf import Roboto
+import torch
+from modules.shared import opts
+
+
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, obj)
+
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, d):
+ if 'TORCHTENSOR' in d:
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
+ return d
+
+
+def embedding_to_b64(data):
+ d = json.dumps(data, cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+
+def embedding_from_b64(data):
+ d = base64.b64decode(data)
+ return json.loads(d, cls=EmbeddingDecoder)
+
+
+def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
+ while True:
+ seed = (a * seed + c) % m
+ yield seed % 255
+
+
+def xor_block(block):
+ g = lcg()
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
+
+
+def style_block(block, sequence):
+ im = Image.new('RGB', (block.shape[1], block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i = 0
+ for x in range(-6, im.size[0], 8):
+ for yi, y in enumerate(range(-6, im.size[1], 8)):
+ offset = 0
+ if yi % 2 == 0:
+ offset = 4
+ shade = sequence[i % len(sequence)]
+ i += 1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+
+ return block ^ fg
+
+
+def insert_image_data_embed(image, data):
+ d = 3
+ data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
+ data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
+ data_np_high = data_np_ >> 4
+ data_np_low = data_np_ & 0x0F
+
+ h = image.size[1]
+ next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
+ next_size = next_size + ((h*d)-(next_size % (h*d)))
+
+ data_np_low.resize(next_size)
+ data_np_low = data_np_low.reshape((h, -1, d))
+
+ data_np_high.resize(next_size)
+ data_np_high = data_np_high.reshape((h, -1, d))
+
+ edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
+
+ data_np_low = style_block(data_np_low, sequence=edge_style)
+ data_np_low = xor_block(data_np_low)
+ data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
+ data_np_high = xor_block(data_np_high)
+
+ im_low = Image.fromarray(data_np_low, mode='RGB')
+ im_high = Image.fromarray(data_np_high, mode='RGB')
+
+ background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
+ background.paste(im_low, (0, 0))
+ background.paste(image, (im_low.size[0]+1, 0))
+ background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
+
+ return background
+
+
+def crop_black(img, tol=0):
+ mask = (img > tol).all(2)
+ mask0, mask1 = mask.any(0), mask.any(1)
+ col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
+ row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
+ return img[row_start:row_end, col_start:col_end]
+
+
+def extract_image_data_embed(image):
+ d = 3
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
+ if black_cols[0].shape[0] < 2:
+ print('No Image data blocks found.')
+ return None
+
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
+ data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
+
+ data_block_lower = xor_block(data_block_lower)
+ data_block_upper = xor_block(data_block_upper)
+
+ data_block = (data_block_upper << 4) | (data_block_lower)
+ data_block = data_block.flatten().tobytes()
+
+ data = zlib.decompress(data_block)
+ return json.loads(data, cls=EmbeddingDecoder)
+
+
+def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from math import cos
+
+ image = srcimage.copy()
+ fontsize = 32
+ if textfont is None:
+ try:
+ textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
+ textfont = opts.font or Roboto
+ except Exception:
+ textfont = Roboto
+
+ factor = 1.5
+ gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
+ for y in range(image.size[1]):
+ mag = 1-cos(y/image.size[1]*factor)
+ mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
+ gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
+ image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
+
+ draw = ImageDraw.Draw(image)
+
+ font = ImageFont.truetype(textfont, fontsize)
+ padding = 10
+
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
+ font = ImageFont.truetype(textfont, fontsize)
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
+
+ _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
+ fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
+ fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
+ fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+
+ font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+
+ draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
+
+ return image
+
+
+if __name__ == '__main__':
+
+ testEmbed = Image.open('test_embedding.png')
+ data = extract_image_data_embed(testEmbed)
+ assert data is not None
+
+ data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
+ assert data is not None
+
+ image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
+ cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
+
+ test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
+
+ embedded_image = insert_image_data_embed(cap_image, test_embed)
+
+ retrived_embed = extract_image_data_embed(embedded_image)
+
+ assert str(retrived_embed) == str(test_embed)
+
+ embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
+
+ assert embedded_image == embedded_image2
+
+ g = lcg()
+ shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
+
+ reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
+ 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
+ 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
+ 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
+ 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
+ 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
+ 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
+ 204, 86, 73, 222, 44, 198, 118, 240, 97]
+
+ assert shared_random == reference_random
+
+ hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
+
+ assert 12731374 == hunna_kay_random_sum
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..2062726a
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,69 @@
+import tqdm
+
+
+class LearnScheduleIterator:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ """
+
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..33eaddb6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,16 +1,46 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
+import time
from modules import shared, images
+from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+ import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
@@ -21,84 +51,97 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- if process_caption:
- shared.interrogator.load()
+ def save_pic_with_caption(image, index, existing_caption=None):
+ caption = ""
- def save_pic_with_caption(image, index):
if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
+ caption += shared.interrogator.generate_caption(image)
+
+ if process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = filename
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{subindex[0]}-{filename_part}"
+ image.save(os.path.join(dst, f"{basename}.png"))
+
+ if preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
- def save_pic(image, index):
- save_pic_with_caption(image, index)
+ def save_pic(image, index, existing_caption=None):
+ save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
+
+ def split_pic(image, inverse_xy):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
+
+ existing_caption = None
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
if shared.state.interrupted:
break
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
-
- if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
-
- top = img.crop((0, 0, size, size))
- save_pic(top, index)
-
- bot = img.crop((0, img.height - size, size, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
-
- left = img.crop((0, 0, size, size))
- save_pic(left, index)
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
- right = img.crop((img.width - size, 0, img.width, size))
- save_pic(right, index)
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy):
+ save_pic(splitted, index, existing_caption=existing_caption)
else:
- img = images.resize_image(1, img, size, size)
- save_pic(img, index)
+ img = images.resize_image(1, img, width, height)
+ save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
- else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
- break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()
diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png
new file mode 100644
index 00000000..07e2d9af
--- /dev/null
+++ b/modules/textual_inversion/test_embedding.png
Binary files differ
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..529ed3e2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,11 +6,17 @@ import torch
import tqdm
import html
import datetime
+import csv
+from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+ insert_image_data_embed, extract_image_data_embed,
+ caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -80,7 +86,18 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
+ else:
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -120,6 +137,7 @@ class EmbeddingDatabase:
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ print("Embeddings:", ', '.join(self.word_embeddings.keys()))
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -135,7 +153,7 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -147,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -156,7 +175,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = step // epoch_len
+ epoch_step = step - epoch * epoch_len
+
+ csv_writer.writerow({
+ "step": step + 1,
+ "epoch": epoch + 1,
+ "epoch_step": epoch_step + 1,
+ **values,
+ })
+
+
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -178,43 +223,51 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
+ embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text) in pbar:
+ for i, entries in pbar:
embedding.step = i + ititial_step
- if embedding.step > steps:
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
-
- x = x.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
@@ -223,30 +276,83 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=text,
- steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
+ do_not_reload_embeddings=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
image.save(last_saved_image)
- last_saved_image += f", prompt: {text}"
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
@@ -254,7 +360,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
@@ -268,4 +374,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.save(filename)
return embedding, filename
-
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..e712284d 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -23,6 +23,8 @@ def preprocess(*args):
def train_embedding(*args):
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
try:
sd_hijack.undo_optimizations()