aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py156
1 files changed, 154 insertions, 2 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d6977950..8c66aeb5 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,10 +7,124 @@ import tqdm
import html
import datetime
+from PIL import Image,PngImagePlugin,ImageDraw
+from ..images import captionImageOverlay
+import numpy as np
+import base64
+import json
+import zlib
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
class EmbeddingEncoder(json.JSONEncoder):
    """JSON encoder that can serialise ``torch.Tensor`` values.

    A tensor is written as a one-key object ``{'TORCHTENSOR': <nested list>}``;
    every other unsupported type is handed to the stock ``JSONEncoder``.
    """

    def default(self, obj):
        """Return a JSON-compatible stand-in for *obj*."""
        if not isinstance(obj, torch.Tensor):
            # Let the base class raise its usual TypeError.
            return json.JSONEncoder.default(self, obj)
        nested_values = obj.cpu().detach().numpy().tolist()
        return {'TORCHTENSOR': nested_values}
+
class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that turns ``{'TORCHTENSOR': ...}`` objects into tensors."""

    def __init__(self, *args, **kwargs):
        """Build the decoder with :meth:`object_hook` installed as the object hook."""
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        """Convert a decoded JSON object carrying the tensor tag; pass others through."""
        if 'TORCHTENSOR' not in d:
            return d
        return torch.from_numpy(np.array(d['TORCHTENSOR']))
+
def embeddingToB64(data):
    """Serialise *data* to JSON with ``EmbeddingEncoder`` and return it base64-encoded (bytes)."""
    json_bytes = json.dumps(data, cls=EmbeddingEncoder).encode()
    return base64.b64encode(json_bytes)
+
def embeddingFromB64(data):
    """Base64-decode *data* and parse the resulting JSON with ``EmbeddingDecoder``."""
    json_bytes = base64.b64decode(data)
    return json.loads(json_bytes, cls=EmbeddingDecoder)
+
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    """Yield an endless linear congruential sequence.

    Each value is ``(a * previous + c) % m``, starting from *seed*; the seed
    itself is not yielded.
    """
    state = seed
    while True:
        state = (state * a + c) % m
        yield state
+
def xorBlock(block):
    """XOR the low four bits of every element of *block* with an LCG-derived mask.

    The mask is regenerated from a default-seeded ``lcg()`` on every call, so
    it depends only on the block's shape. Applying the function twice to a
    uint8 block therefore restores the original values.

    Args:
        block: numpy array; it is cast to ``uint8`` before the XOR.

    Returns:
        A ``uint8`` array with the same shape as *block*.
    """
    stream = lcg()
    # ``block.size`` replaces ``np.product(block.shape)``: np.product was
    # removed in NumPy 2.0.
    keystream = np.array([next(stream) for _ in range(block.size)], dtype=np.uint64)
    # Keep the low byte of each value, as a uint8 cast does.
    mask = keystream.astype(np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), mask & 0x0F)
+
def styleBlock(block, sequence):
    """XOR a decorative dot pattern into the high four bits of *block*.

    A black RGB canvas the size of the block is filled with 6-pixel ellipses
    on an 8-pixel grid. Even-indexed rows within each column are shifted 4
    pixels to the right. Ellipse shades are taken from *sequence* in drawing
    order, wrapping around. Only the high nibble of the rendered canvas is
    kept before it is XOR-ed with *block*.
    """
    height, width = block.shape[0], block.shape[1]
    canvas = Image.new('RGB', (width, height))
    pen = ImageDraw.Draw(canvas)

    row_tops = range(-6, canvas.size[1], 8)
    drawn = 0
    for left in range(-6, canvas.size[0], 8):
        for row_index, top in enumerate(row_tops):
            shift = 4 if row_index % 2 == 0 else 0
            shade = sequence[drawn % len(sequence)]
            drawn += 1
            pen.ellipse((left + shift, top, left + 6 + shift, top + 6), fill=(shade, shade, shade))

    overlay = np.array(canvas).astype(np.uint8) & 0xF0
    return block ^ overlay
+
def insertImageDataEmbed(image, data):
    """Return a copy of *image* flanked by two pixel strips that encode *data*.

    *data* is JSON-serialised with ``EmbeddingEncoder`` and zlib-compressed.
    Each compressed byte is split into its low and high nibble. The
    low-nibble stream becomes an RGB strip pasted to the left of the image
    and the high-nibble stream a strip pasted to the right, with a one-pixel
    black column separating each strip from the image. The payload nibble
    sits in the low four bits of each pixel value (scrambled by
    ``xorBlock``); the high four bits carry a decorative pattern from
    ``styleBlock``.

    Args:
        image: PIL image; ``image.size[1]`` fixes the strip height.
        data: embedding dict. ``data['string_to_param']`` must hold at least
            one tensor, whose first row seeds the decorative pattern.

    Returns:
        A new RGB PIL image of the same height, wider than *image* by the two
        strips plus two separator columns.
    """
    d = 3
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    dnp = np.frombuffer(data_compressed, np.uint8).copy()
    dnphigh = dnp >> 4
    dnplow = dnp & 0x0F

    # Round the payload length up so it reshapes into (h, width, d).
    h = image.size[1]
    next_size = dnplow.shape[0] + (h - (dnplow.shape[0] % h))
    next_size = next_size + ((h * d) - (next_size % (h * d)))

    # Zero-pad with np.pad rather than the in-place ndarray.resize, which can
    # fail its reference check; the resulting contents are the same.
    pad_amount = next_size - dnplow.shape[0]
    dnplow = np.pad(dnplow, (0, pad_amount)).reshape((h, -1, d))
    dnphigh = np.pad(dnphigh, (0, pad_amount)).reshape((h, -1, d))

    # Shade sequence for the pattern: |weights| scaled to 0..255.
    first_row = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
    magnitudes = np.abs(first_row)
    peak = np.max(magnitudes) if magnitudes.size else 0
    if peak > 0:
        edgeStyleWeights = (magnitudes / peak * 255).astype(np.uint8)
    else:
        # All-zero (or empty) weights would divide by zero; use a flat black
        # pattern with at least one entry so styleBlock can index it.
        edgeStyleWeights = np.zeros(max(magnitudes.size, 1), dtype=np.uint8)

    dnplow = styleBlock(dnplow, sequence=edgeStyleWeights)
    dnplow = xorBlock(dnplow)
    dnphigh = styleBlock(dnphigh, sequence=edgeStyleWeights[::-1])
    dnphigh = xorBlock(dnphigh)

    imlow = Image.fromarray(dnplow, mode='RGB')
    imhigh = Image.fromarray(dnphigh, mode='RGB')

    # Layout: [low strip][1px black][image][1px black][high strip]
    background = Image.new('RGB', (image.size[0] + imlow.size[0] + imhigh.size[0] + 2, image.size[1]), (0, 0, 0))
    background.paste(imlow, (0, 0))
    background.paste(image, (imlow.size[0] + 1, 0))
    background.paste(imhigh, (imlow.size[0] + 1 + image.size[0] + 1, 0))

    return background
+
def crop_black(img, tol=0):
    """Trim outer rows and columns that contain no "lit" pixel.

    A pixel counts as lit only when every channel exceeds *tol*. The result
    is the slice of *img* spanning the first to the last lit row and column.
    If no pixel is lit, the whole array is returned.
    """
    lit = (img > tol).all(2)
    lit_cols = lit.any(0)
    lit_rows = lit.any(1)

    left = lit_cols.argmax()
    right = lit.shape[1] - lit_cols[::-1].argmax()
    top = lit_rows.argmax()
    bottom = lit.shape[0] - lit_rows[::-1].argmax()

    return img[top:bottom, left:right]
+
def extractImageDataEmbed(image):
    """Recover the embedding stored in the side strips of *image*.

    The image is converted to RGB, cropped with ``crop_black`` and reduced to
    the low four bits of each value. Columns whose low nibbles are all zero
    act as separators. Everything left of the first separator is read as the
    low-nibble strip and everything right of the last one as the high-nibble
    strip. Both are unscrambled with ``xorBlock``, recombined into bytes,
    zlib-decompressed and parsed as JSON with ``EmbeddingDecoder``.

    Args:
        image: PIL image to read.

    Returns:
        The decoded object, or ``None`` (after printing a message) when fewer
        than two separator columns are found.
    """
    # Direct array conversion gives the same (height, width, 3) uint8 array
    # as reshaping getdata(), without building a Python list of pixel tuples.
    pixels = np.array(image.convert('RGB')).astype(np.uint8)
    outarr = crop_black(pixels) & 0x0F

    blackCols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
    if blackCols[0].shape[0] < 2:
        print('No Image data blocks found.')
        return None

    dataBlocklower = outarr[:, :blackCols[0].min(), :].astype(np.uint8)
    dataBlockupper = outarr[:, blackCols[0].max() + 1:, :].astype(np.uint8)

    dataBlocklower = xorBlock(dataBlocklower)
    dataBlockupper = xorBlock(dataBlockupper)

    # Reassemble each byte from its high and low nibble.
    dataBlock = (dataBlockupper << 4) | (dataBlocklower)
    dataBlock = dataBlock.flatten().tobytes()
    data = zlib.decompress(dataBlock)
    return json.loads(data, cls=EmbeddingDecoder)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -80,7 +194,18 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if filename.upper().endswith('.PNG'):
+ embed_image = Image.open(path)
+ if 'sd-ti-embedding' in embed_image.text:
+ data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name',name)
+ else:
+ data = extractImageDataEmbed(embed_image)
+ name = data.get('name',name)
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
@@ -156,7 +281,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt)
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -178,6 +304,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -254,6 +386,26 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
image = processed.images[0]
shared.state.current_image = image
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embeddingToB64(data))
+
+ title = "<{}>".format(data.get('name','???'))
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}'.format(embedding.step)
+
+ captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+ captioned_image = insertImageDataEmbed(captioned_image,data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"