aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/dataset.py4
-rw-r--r--modules/textual_inversion/image_embedding.py5
-rw-r--r--modules/textual_inversion/preprocess.py40
-rw-r--r--modules/textual_inversion/textual_inversion.py6
-rw-r--r--modules/textual_inversion/ui.py4
5 files changed, 39 insertions, 20 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..5b1c5002 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry)
- assert len(self.dataset) > 1, "No images have been found in the dataset."
+ assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text):
text = random.choice(self.lines)
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 898ce3b3..ea653806 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
+from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos
image = srcimage.copy()
-
+ fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
- fontsize = 32
+
font = ImageFont.truetype(textfont, fontsize)
padding = 10
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..6bba3852 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index):
+ def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
@@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
+ if preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
- def save_pic(image, index):
- save_pic_with_caption(image, index)
+ def save_pic(image, index, existing_caption=None):
+ save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
@@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception:
continue
+ existing_caption = None
+
+ try:
+ existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
+ except Exception as e:
+ print(e)
+
if shared.state.interrupted:
break
@@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height))
- save_pic(top, index)
+ save_pic(top, index, existing_caption=existing_caption)
bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
+ save_pic(bot, index, existing_caption=existing_caption)
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height))
- save_pic(left, index)
+ save_pic(left, index, existing_caption=existing_caption)
right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
+ save_pic(right, index, existing_caption=existing_caption)
else:
img = images.resize_image(1, img, width, height)
- save_pic(img, index)
+ save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..529ed3e2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
+
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 36881e7a..e712284d 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()