aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/textual_inversion/textual_inversion.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f6316020..1b7f8906 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,9 +7,11 @@ import tqdm
import html
import datetime
-from PIL import Image, PngImagePlugin
+from PIL import Image,PngImagePlugin
+from ..images import captionImge
+import numpy as np
import base64
-from io import BytesIO
+import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -87,9 +89,9 @@ class EmbeddingDatabase:
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
- if 'sd-embedding' in embed_image.text:
- embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
- data = torch.load(BytesIO(embeddingData), map_location="cpu")
+ if 'sd-ti-embedding' in embed_image.text:
+ data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if save_image_with_stored_embedding:
info = PngImagePlugin.PngInfo()
- info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
- image.save(last_saved_image, "PNG", pnginfo=info)
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embeddingToB64(data))
+
+ pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
+
+ caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
+ caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
+ caption_stepcount = data.get('step',0)
+ caption_stepcount = caption_stepcount if caption_stepcount else 0
+
+ post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
+ caption_stepcount))]
+ captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
+ captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else:
image.save(last_saved_image)
-
-
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step