aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e4f339b8..21596e78 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -8,7 +8,7 @@ import html
import datetime
from PIL import Image,PngImagePlugin
-from ..images import captionImge
+from ..images import captionImageOverlay
import numpy as np
import base64
import json
@@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.current_image = image
- if save_image_with_stored_embedding:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embeddingToB64(data))
- pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
+ title = "<{}>".format(data.get('name','???'))
checkpoint = sd_models.select_checkpoint()
- post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash,
- embedding.step))]
- captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
- captioned_image.save(last_saved_image, "PNG", pnginfo=info)
- else:
- image.save(last_saved_image)
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '[{}]'.format(embedding.step)
+
+ captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
+ image.save(last_saved_image)
last_saved_image += f", prompt: {text}"