about summary refs log tree commit diff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--  modules/textual_inversion/textual_inversion.py  |  88
1 files changed, 65 insertions, 23 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 667a7cf2..95eebea7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -39,20 +39,59 @@ def embeddingFromB64(data):
d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder)
-def appendImageDataFooter(image,data):
+def xorBlock(block):
+ return np.bitwise_xor(block.astype(np.uint8),
+ ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F )
+
+def styleBlock(block,sequence):
+ im = Image.new('RGB',(block.shape[1],block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i=0
+ for x in range(-6,im.size[0],8):
+ for yi,y in enumerate(range(-6,im.size[1],8)):
+ offset=0
+ if yi%2==0:
+ offset=4
+ shade = sequence[i%len(sequence)]
+ i+=1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) )
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+ return block ^ fg
+
+def insertImageDataEmbed(image,data):
d = 3
data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
dnp = np.frombuffer(data_compressed,np.uint8).copy()
- w = image.size[0]
- next_size = dnp.shape[0] + (w-(dnp.shape[0]%w))
- next_size = next_size + ((w*d)-(next_size%(w*d)))
- dnp.resize(next_size)
- dnp = dnp.reshape((-1,w,d))
- print(dnp.shape)
- im = Image.fromarray(dnp,mode='RGB')
- background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0))
- background.paste(image,(0,0))
- background.paste(im,(0,image.size[1]+1))
+ dnphigh = dnp >> 4
+ dnplow = dnp & 0x0F
+
+ h = image.size[1]
+ next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h))
+ next_size = next_size + ((h*d)-(next_size%(h*d)))
+
+ dnplow.resize(next_size)
+ dnplow = dnplow.reshape((h,-1,d))
+
+ dnphigh.resize(next_size)
+ dnphigh = dnphigh.reshape((h,-1,d))
+
+ edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8)
+
+ dnplow = styleBlock(dnplow,sequence=edgeStyleWeights)
+ dnplow = xorBlock(dnplow)
+ dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1])
+ dnphigh = xorBlock(dnphigh)
+
+ imlow = Image.fromarray(dnplow,mode='RGB')
+ imhigh = Image.fromarray(dnphigh,mode='RGB')
+
+ background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0))
+ background.paste(imlow,(0,0))
+ background.paste(image,(imlow.size[0]+1,0))
+ background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0))
+
return background
def crop_black(img,tol=0):
@@ -62,19 +101,22 @@ def crop_black(img,tol=0):
row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
return img[row_start:row_end,col_start:col_end]
-def extractImageDataFooter(image):
+def extractImageDataEmbed(image):
d=3
- outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
- lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
- if lastRow[0].shape[0] == 0:
- print('Image data block not found.')
+ outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F
+ blackCols = np.where( np.sum(outarr, axis=(0,2))==0)
+ if blackCols[0].shape[0] < 2:
+ print('No Image data blocks found.')
return None
- lastRow = lastRow[0]
-
- lastRow = lastRow.max()
- dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes()
- print(lastRow)
+ dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8)
+ dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8)
+
+ dataBlocklower = xorBlock(dataBlocklower)
+ dataBlockupper = xorBlock(dataBlockupper)
+
+ dataBlock = (dataBlockupper << 4) | (dataBlocklower)
+ dataBlock = dataBlock.flatten().tobytes()
data = zlib.decompress(dataBlock)
return json.loads(data,cls=EmbeddingDecoder)
@@ -154,7 +196,7 @@ class EmbeddingDatabase:
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name)
else:
- data = extractImageDataFooter(embed_image)
+ data = extractImageDataEmbed(embed_image)
name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@@ -351,7 +393,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
footer_right = '{}'.format(embedding.step)
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
- captioned_image = appendImageDataFooter(captioned_image,data)
+ captioned_image = insertImageDataEmbed(captioned_image,data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)