aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
authorDepFA <35278260+dfaker@users.noreply.github.com>2022-10-09 05:38:38 +0100
committerGitHub <noreply@github.com>2022-10-09 05:38:38 +0100
commit5841990b0df04906da7321beef6f7f7902b7d57b (patch)
tree80e405670fae2d4c4933a23d3a2f9a2f8269b201 /modules/textual_inversion/textual_inversion.py
parent050a6a798cec90ae2f881c2ddd3f0221e69907dc (diff)
Update textual_inversion.py
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py25
1 files changed, 22 insertions, 3 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..f6316020 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,6 +7,9 @@ import tqdm
import html
import datetime
+from PIL import Image, PngImagePlugin
+import base64
+from io import BytesIO
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@@ -80,7 +83,15 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if filename.upper().endswith('.PNG'):
+ embed_image = Image.open(path)
+ if 'sd-embedding' in embed_image.text:
+ embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
+ data = torch.load(BytesIO(embeddingData), map_location="cpu")
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
image = processed.images[0]
shared.state.current_image = image
- image.save(last_saved_image)
+
+ if save_image_with_stored_embedding:
+ info = PngImagePlugin.PngInfo()
+ info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
+ image.save(last_saved_image, "PNG", pnginfo=info)
+ else:
+ image.save(last_saved_image)
+
+
last_saved_image += f", prompt: {text}"