From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Fri, 14 Oct 2022 10:56:41 +0200 Subject: init --- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++-------- 2 files changed, 26 insertions(+), 11 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 67e90afe..59b2b021 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -48,7 +48,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) except Exception: continue diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fa0e33a2..b12a8e6d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, + create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, + preview_image_prompt, batch_size=1, + gradient_accumulation=1): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=embedding_name, model=shared.sd_model, + device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step) for i, entry in pbar: embedding.step = i + ititial_step @@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini break with torch.autocast("cuda"): - c = cond_model([entry.cond_text]) + c = cond_model([e.cond_text for e in entry]) + + x = torch.stack([e.latent for e in entry]).to(devices.device) + loss = shared.sd_model(x, c)[0] - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] del x losses[embedding.step % losses.shape[0]] = loss.item() - optimizer.zero_grad() loss.backward() - optimizer.step() + if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step): + optimizer.step() + optimizer.zero_grad() + epoch_num = embedding.step // len(ds) epoch_step = embedding.step - (epoch_num * len(ds)) + 1 @@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

-- cgit v1.2.1 From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, re organization of the code --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 68ceffe3..23bb4b6a 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -49,7 +49,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: continue -- cgit v1.2.1 From 0087079c2d487b67b06ffc30f36ce486a74e6318 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:10:59 +0100 Subject: allow overwrite old embedding --- modules/textual_inversion/textual_inversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3be69562..5776778b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -153,7 +153,7 @@ class EmbeddingDatabase: return None, None -def create_embedding(name, num_vectors_per_token, init_text='*'): +def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings @@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" embedding = Embedding(vec, name) embedding.step = 0 -- cgit v1.2.1 From c3835ec85cbb44fa3c46fa871c622b6fee235c89 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:24:24 +0100 Subject: pass overwrite old flag --- modules/textual_inversion/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 36881e7a..e712284d 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess from modules import sd_hijack, shared -def create_embedding(name, initialization_text, nvpt): - filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text) +def create_embedding(name, initialization_text, nvpt, overwrite_old): + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() -- cgit v1.2.1 From fbcce66601994f6ed370db36d9c238840fed6bd2 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:46:54 +0100 Subject: add existing caption file handling --- modules/textual_inversion/preprocess.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..5c43fe13 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - def save_pic_with_caption(image, index): + def save_pic_with_caption(image, index, existing_caption=None): caption = "" if process_caption: @@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro basename = f"{index:05}-{subindex[0]}-{filename_part}" image.save(os.path.join(dst, f"{basename}.png")) + if preprocess_txt_action == 'prepend' and existing_caption: + caption = existing_caption + ' ' + caption + elif preprocess_txt_action == 'append' and existing_caption: + caption = caption + ' ' + existing_caption + elif preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + if len(caption) > 0: with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) subindex[0] += 1 - def save_pic(image, index): + def save_pic(image, index, existing_caption=None): save_pic_with_caption(image, index) if process_flip: - save_pic_with_caption(ImageOps.mirror(image), index) + save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] @@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro except Exception: continue + existing_caption = None + + try: + existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() + except Exception as e: + print(e) + if shared.state.interrupted: break @@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = img.resize((width, height * img.height // img.width)) top = img.crop((0, 0, width, height)) - save_pic(top, index) + save_pic(top, index, existing_caption=existing_caption) bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) + save_pic(bot, index, existing_caption=existing_caption) elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) left = img.crop((0, 0, width, height)) - save_pic(left, index) + save_pic(left, index, existing_caption=existing_caption) right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + save_pic(right, index, existing_caption=existing_caption) else: img = images.resize_image(1, img, width, height) - save_pic(img, index) + save_pic(img, index, existing_caption=existing_caption) shared.state.nextjob() -- cgit v1.2.1 From 9b65c4ecf4f8eb6187ee721918adebe68e9bc631 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:49:23 +0100 Subject: pass preprocess_txt_action param --- modules/textual_inversion/preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 5c43fe13..3713bc89 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru) finally: @@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): width = process_width height = process_height src = os.path.abspath(process_src) -- cgit v1.2.1 From 858462f719c22ca9f24b94a41699653c34b5f4fb Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 02:57:18 +0100 Subject: do caption copy for both flips --- modules/textual_inversion/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3713bc89..6bba3852 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -82,7 +82,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre subindex[0] += 1 def save_pic(image, index, existing_caption=None): - save_pic_with_caption(image, index) + save_pic_with_caption(image, index, existing_caption=existing_caption) if process_flip: save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) -- cgit v1.2.1 From 9681419e422515e42444e0174355b760645a846f Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 20 Oct 2022 16:53:46 +0900 Subject: train: fixed preprocess image ratio --- modules/textual_inversion/preprocess.py | 54 +++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 19 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..2743bdeb 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,6 @@ import os from PIL import Image, ImageOps +import math import platform import sys import tqdm @@ -38,6 +39,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) + split_threshold = 0.5 + overlap_ratio = 0.2 assert src != dst, 'same directory specified as source and destination' @@ -78,6 +81,29 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + def split_pic(image, inverse_xy): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -89,26 +115,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if shared.state.interrupted: break - ratio = img.height / img.width - is_tall = ratio > 1.35 - is_wide = ratio < 1 / 1.35 - - if process_split and is_tall: - img = img.resize((width, height * img.height // img.width)) - - top = img.crop((0, 0, width, height)) - save_pic(top, index) - - bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) - elif process_split and is_wide: - img = img.resize((width * img.width // img.height, height)) - - left = img.crop((0, 0, width, height)) - save_pic(left, index) + if img.height > img.width: + ratio = (img.width * height) / (img.height * width) + inverse_xy = False + else: + ratio = (img.height * width) / (img.width * height) + inverse_xy = True - right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + if process_split and ratio < 1.0 and ratio <= split_threshold: + for splitted in split_pic(img, inverse_xy): + save_pic(splitted, index) else: img = images.resize_image(1, img, width, height) save_pic(img, index) -- cgit v1.2.1 From 85dd62c4c7635b8e21a75f140d093036069e97a1 Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 20 Oct 2022 22:56:45 +0900 Subject: train: ui: added `Split image threshold` and `Split image overlap ratio` to preprocess --- modules/textual_inversion/preprocess.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 2743bdeb..c8df8aa0 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): try: if process_caption: shared.interrogator.load() @@ -22,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio) finally: @@ -34,13 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): width = process_width height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) - split_threshold = 0.5 - overlap_ratio = 0.2 + split_threshold = max(0.0, min(1.0, split_threshold)) + overlap_ratio = max(0.0, min(0.9, overlap_ratio)) assert src != dst, 'same directory specified as source and destination' -- cgit v1.2.1 From b69c37d25e4ffc56e8f8c247fa2c38b4648cefb7 Mon Sep 17 00:00:00 2001 From: guaneec Date: Thu, 20 Oct 2022 22:21:12 +0800 Subject: Allow datasets with only 1 image in TI --- modules/textual_inversion/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 23bb4b6a..5b1c5002 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -83,7 +83,7 @@ class PersonalizedBase(Dataset): self.dataset.append(entry) - assert len(self.dataset) > 1, "No images have been found in the dataset." + assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset) * repeats // batch_size self.initial_indexes = np.arange(len(self.dataset)) @@ -91,7 +91,7 @@ class PersonalizedBase(Dataset): self.shuffle() def shuffle(self): - self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] def create_text(self, filename_text): text = random.choice(self.lines) -- cgit v1.2.1 From d0ea471b0cdaede163c6e7f6fae8535f5c3cd226 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 21 Oct 2022 14:04:41 +0100 Subject: Use opts in textual_inversion image_embedding.py for dynamic fonts --- modules/textual_inversion/image_embedding.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 898ce3b3..c50b1e7b 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -5,6 +5,7 @@ import zlib from PIL import Image, PngImagePlugin, ImageDraw, ImageFont from fonts.ttf import Roboto import torch +from modules.shared import opts class EmbeddingEncoder(json.JSONEncoder): -- cgit v1.2.1 From 306e2ff6ab8f4c7e94ab55f4f08ab8f94d73d287 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 21 Oct 2022 14:47:21 +0100 Subject: Update image_embedding.py --- modules/textual_inversion/image_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index c50b1e7b..ea653806 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -134,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t from math import cos image = srcimage.copy() - + fontsize = 32 if textfont is None: try: textfont = ImageFont.truetype(opts.font or Roboto, fontsize) @@ -151,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) draw = ImageDraw.Draw(image) - fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) padding = 10 -- cgit v1.2.1 From f49c08ea566385db339c6628f65c3a121033f67c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 18:46:02 +0300 Subject: prevent error spam when processing images without txt files for captions --- modules/textual_inversion/preprocess.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 17e4ddc1..33eaddb6 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -122,11 +122,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre continue existing_caption = None - - try: - existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() - except Exception as e: - print(e) + existing_caption_filename = os.path.splitext(filename)[0] + '.txt' + if os.path.exists(existing_caption_filename): + with open(existing_caption_filename, 'r', encoding="utf8") as file: + existing_caption = file.read() if shared.state.interrupted: break -- cgit v1.2.1