aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/preprocess.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/preprocess.py')
-rw-r--r--modules/textual_inversion/preprocess.py92
1 files changed, 40 insertions, 52 deletions
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 113cecf1..3047bede 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -10,7 +10,28 @@ from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
+
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
+
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -25,30 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- if process_caption:
- shared.interrogator.load()
-
- if process_caption_deepbooru:
- deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
-
def save_pic_with_caption(image, index):
+ caption = ""
+
if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- elif process_caption_deepbooru:
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process_queue.put(image)
- while shared.deepbooru_process_return["value"] == -1:
- time.sleep(0.2)
- caption = "-" + shared.deepbooru_process_return["value"]
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- shared.deepbooru_process_return["value"] = -1
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
+ caption += shared.interrogator.generate_caption(image)
+
+ if process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = filename
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{subindex[0]}-{filename_part}"
+ image.save(os.path.join(dst, f"{basename}.png"))
+
+ if len(caption) > 0:
+ with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
def save_pic(image, index):
@@ -93,34 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
save_pic(img, index)
shared.state.nextjob()
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
- if process_caption_deepbooru:
- deepbooru.release_process()
-
-
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
- else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
- break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()