about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- modules/images.py 77
-rw-r--r-- modules/textual_inversion/textual_inversion.py 61
-rw-r--r-- modules/ui.py 4
3 files changed, 140 insertions, 2 deletions
diff --git a/modules/images.py b/modules/images.py
index c0a90676..e62eec8e 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -463,3 +463,80 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn = None
return fullfn, txt_fullfn
+
+def addCaptionLines(lines,image,initialx,textfont):
+    """Draw horizontally centred caption lines onto ``image``, stacked downwards.
+
+    Each line is measured at a 32pt probe size and then drawn at a size scaled
+    so that it fits the image width minus a 35px margin, capped at 28pt.
+
+    Args:
+        lines: iterable of ``(fill, text)`` pairs; ``fill`` is handed to
+            ``draw.text`` as the text colour.
+        image: image that is drawn on in place.
+        initialx: vertical pixel offset of the first line (despite the name it
+            is used as a y coordinate).
+        textfont: font passed to ``ImageFont.truetype``.
+
+    Returns:
+        The vertical offset just below the last drawn line.
+    """
+    probeSize = 32
+    maxSize = 28
+    draw = ImageDraw.Draw(image)
+    # The probe font does not depend on the line being drawn, so load it once.
+    probeFont = ImageFont.truetype(textfont, probeSize)
+    hstart = initialx
+    for fill, line in lines:
+        _, _, w, _ = draw.textbbox((0, 0), line, font=probeFont)
+        if w > 0:
+            fontSize = min(int(probeSize * ((image.size[0] - 35) / w)), maxSize)
+        else:
+            # A zero-width measurement (e.g. empty text) would divide by zero.
+            fontSize = maxSize
+        # Keep the size positive: on very narrow images the scale factor can
+        # round down to zero or go negative, which is not a usable font size.
+        font = ImageFont.truetype(textfont, max(fontSize, 1))
+        _, _, w, h = draw.textbbox((0, 0), line, font=font)
+        draw.text(((image.size[0] - w) / 2, hstart), line, font=font, fill=fill)
+        hstart += h
+    return hstart
+
+def captionImge(image,prelines,postlines,background=(51, 51, 51),font=None):
+    """Return a new image with caption lines drawn above and below ``image``.
+
+    NOTE: the misspelled function name is left unchanged so that any existing
+    callers keep working.
+
+    Args:
+        image: picture to be captioned; it is pasted, not modified.
+        prelines: ``(fill, text)`` pairs drawn above the picture.
+        postlines: ``(fill, text)`` pairs drawn below the picture.
+        background: fill colour of the caption canvas.
+        font: font passed to ``ImageFont.truetype``; when ``None`` the
+            configured ``opts.font`` is used, falling back to ``Roboto``.
+
+    Returns:
+        An RGBA image as wide as ``image``, cropped to the drawn content plus
+        an 8px bottom margin.
+    """
+    if font is None:
+        try:
+            candidate = opts.font or Roboto
+            # Probe with a concrete size so an unloadable configured font is
+            # detected here instead of failing later while drawing.
+            ImageFont.truetype(candidate, 32)
+            font = candidate
+        except Exception:
+            # Best-effort: any problem with the configured font falls back to
+            # the bundled default.
+            font = Roboto
+
+    width, height = image.size
+    # Over-allocate 1024px of vertical room for the text; the unused part is
+    # cropped away once the final offset is known.
+    canvas = Image.new("RGBA", (width, height + 1024), background)
+    hoffset = addCaptionLines(prelines, canvas, 5, font) + 16
+    canvas.paste(image, (0, hoffset))
+    hoffset += height + 8
+    hoffset = addCaptionLines(postlines, canvas, hoffset, font)
+    return canvas.crop((0, 0, width, hoffset + 8))
+
+def captionImageOverlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None):
+    """Return a copy of ``srcimage`` with a title and three footers drawn over it.
+
+    A black overlay whose opacity rises towards the top and bottom edges is
+    composited over the copy, then the title is drawn top-left and the three
+    footer strings are drawn along the bottom (left, centred, right) in white.
+
+    Args:
+        srcimage: source image; it is copied, not modified.
+        title: text drawn in the top-left corner, scaled to at most 75% of the
+            image width (minus padding) and capped at 72pt.
+        footerLeft: text anchored to the bottom-left corner.
+        footerMid: text centred along the bottom edge.
+        footerRight: text anchored to the bottom-right corner.
+        textfont: font passed to ``ImageFont.truetype``; when ``None`` the
+            configured ``opts.font`` is used, falling back to ``Roboto``.
+
+    Returns:
+        The captioned RGBA image.
+    """
+    from math import cos
+
+    image = srcimage.copy()
+
+    if textfont is None:
+        try:
+            candidate = opts.font or Roboto
+            # Probe with a concrete size so an unloadable configured font is
+            # detected here instead of failing later while drawing.
+            ImageFont.truetype(candidate, 32)
+            textfont = candidate
+        except Exception:
+            # Best-effort: any problem with the configured font falls back to
+            # the bundled default.
+            textfont = Roboto
+
+    factor = 1.5
+    width, height = image.size
+    # One-pixel-wide alpha ramp, stretched to the full image width below.
+    gradient = Image.new('RGBA', (1, height), color=(0,0,0,0))
+    for y in range(height):
+        mag = 1 - cos(y / height * factor)
+        mag = max(mag, 1 - cos((height - y) / height * factor * 1.1))
+        # mag exceeds 1.0 near the top edge, so clamp alpha to the 8-bit range.
+        gradient.putpixel((0, y), (0, 0, 0, min(int(mag * 255), 255)))
+    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
+
+    draw = ImageDraw.Draw(image)
+    padding = 10
+    textFill = (255,255,255,230)
+
+    def fitSize(text, measureFont, measureSize, maxWidth, cap=72):
+        """Scale ``measureSize`` so ``text`` spans ``maxWidth``, kept within 1..cap."""
+        _, _, w, _ = draw.textbbox((0, 0), text, font=measureFont)
+        if w <= 0:
+            # Nothing measurable (e.g. empty text): do not constrain the size.
+            return cap
+        return max(1, min(int(measureSize * (maxWidth / w)), cap))
+
+    baseSize = 32
+    baseFont = ImageFont.truetype(textfont, baseSize)
+    titleSize = fitSize(title, baseFont, baseSize, (width * 0.75) - (padding * 4))
+    titleFont = ImageFont.truetype(textfont, titleSize)
+    draw.text((padding,padding), title, anchor='lt', font=titleFont, fill=textFill)
+
+    # All footers share the smallest size that lets each fit a third of the
+    # width; they are measured with the title font, as scaled above.
+    footerWidth = (width / 3) - padding
+    footerSize = min(
+        fitSize(text, titleFont, titleSize, footerWidth)
+        for text in (footerLeft, footerMid, footerRight)
+    )
+    footerFont = ImageFont.truetype(textfont, footerSize)
+
+    draw.text((padding,height-padding), footerLeft, anchor='ls', font=footerFont, fill=textFill)
+    draw.text((width/2,height-padding), footerMid, anchor='ms', font=footerFont, fill=textFill)
+    draw.text((width-padding,height-padding), footerRight, anchor='rs', font=footerFont, fill=textFill)
+
+    return image
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..7a24192e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,10 +7,36 @@ import tqdm
import html
import datetime
+from PIL import Image,PngImagePlugin
+from ..images import captionImageOverlay
+import numpy as np
+import base64
+import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+class EmbeddingEncoder(json.JSONEncoder):
+    """JSON encoder that serialises ``torch.Tensor`` values as tagged nested lists."""
+
+    def default(self, obj):
+        """Encode a tensor as ``{'TORCHTENSOR': <nested list>}``.
+
+        Any other unsupported object is deferred to the base class, which
+        raises ``TypeError``.
+        """
+        if isinstance(obj, torch.Tensor):
+            return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
+        return json.JSONEncoder.default(self, obj)
+
+class EmbeddingDecoder(json.JSONDecoder):
+    """JSON decoder that turns ``{'TORCHTENSOR': ...}`` objects back into tensors."""
+
+    def __init__(self, *args, **kwargs):
+        """Install ``object_hook`` so every decoded JSON object passes through it."""
+        super().__init__(*args, object_hook=self.object_hook, **kwargs)
+
+    def object_hook(self, d):
+        """Return a tensor for a tagged object, or ``d`` unchanged otherwise."""
+        if 'TORCHTENSOR' not in d:
+            return d
+        # NOTE(review): the dtype is whatever numpy infers from the nested
+        # list, which may differ from the tensor that was encoded - confirm.
+        values = np.array(d['TORCHTENSOR'])
+        return torch.from_numpy(values)
+
+def embeddingToB64(data):
+    """Serialise ``data`` to JSON with ``EmbeddingEncoder`` and return it base64-encoded as bytes."""
+    payload = json.dumps(data, cls=EmbeddingEncoder).encode()
+    return base64.b64encode(payload)
+
+def embeddingFromB64(data):
+    """Decode base64 ``data`` and parse the resulting JSON with ``EmbeddingDecoder``."""
+    return json.loads(base64.b64decode(data), cls=EmbeddingDecoder)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -80,7 +106,15 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if filename.upper().endswith('.PNG'):
+ embed_image = Image.open(path)
+ if 'sd-ti-embedding' in embed_image.text:
+ data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name',name)
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -178,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -252,6 +292,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
image = processed.images[0]
shared.state.current_image = image
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embeddingToB64(data))
+
+ title = "<{}>".format(data.get('name','???'))
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}'.format(embedding.step)
+
+ captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
diff --git a/modules/ui.py b/modules/ui.py
index 8c06ad7c..0f6427a6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1057,7 +1057,8 @@ def create_ui(wrap_gradio_gpu_call):
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
-
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
+
with gr.Row():
with gr.Column(scale=2):
gr.HTML(value="")
@@ -1124,6 +1125,7 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every,
save_embedding_every,
template_file,
+ save_image_with_stored_embedding,
],
outputs=[
ti_output,