author     brkirch <brkirch@users.noreply.github.com>  2022-10-04 01:04:19 -0400
committer  GitHub <noreply@github.com>                 2022-10-04 01:04:19 -0400
commit     b88e4ea7d6536ddc814a9101a258f87b06230019 (patch)
tree       913bdda300b7082aa61429689d0a730cecda1525 /modules/textual_inversion/dataset.py
parent     bdaa36c84470adbdce3e98c01a69af5e95adfb02 (diff)
parent     2865ef4b9ab16d56326cc805541bebcf01d099bc (diff)
Merge branch 'master' into master
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--  modules/textual_inversion/dataset.py  78
1 file changed, 78 insertions(+), 0 deletions(-)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
new file mode 100644
index 00000000..e8394ff6
--- /dev/null
+++ b/modules/textual_inversion/dataset.py
@@ -0,0 +1,78 @@
+import os
+import random
+
+import numpy as np
+import PIL
+import torch
+import tqdm
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from modules import devices
+
+
+class PersonalizedBase(Dataset):
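+    # Pre-encodes every training image into its VAE latent once at load time,
+    # then serves (latent, caption) pairs built from the template lines.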
+    def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+
+        self.placeholder_token = placeholder_token
+
+        self.size = size
+        self.width = width
+        self.height = height
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
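+        # Note: self.flip is defined but never applied in this version, since
+        # images are encoded to latents up front rather than per access.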
+
+        self.dataset = []
+
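+        # Template lines may contain [name] and [filewords] placeholders,
+        # which are filled in per sample in __getitem__.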
+        with open(template_file, "r") as file:
+            lines = [x.strip() for x in file.readlines()]
+
+        self.lines = lines
+
+        assert data_root, 'dataset directory not specified'
+
+        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+        print("Preparing dataset...")
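+        # Run every image through the VAE encoder once now, so training never
+        # has to re-encode it.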
+        for path in tqdm.tqdm(self.image_paths):
+            image = Image.open(path)
+            image = image.convert('RGB')
+            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+
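+            # Caption words come from the filename: underscores and spaces act
+            # as separators, and non-alphabetic tokens are dropped.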
+            filename = os.path.basename(path)
+            filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
+            filename_tokens = [token for token in filename_tokens if token.isalpha()]
+
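+            # Scale pixel values from [0, 255] to [-1, 1], the input range the
+            # first-stage encoder expects.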
+            npimage = np.array(image).astype(np.uint8)
+            npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
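+            # HWC numpy image -> CHW torch tensor on the model's device.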
+            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
+            torchdata = torch.moveaxis(torchdata, 2, 0)
+
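+            # Encode to a latent, then move it to the CPU so the whole dataset
+            # does not accumulate in VRAM.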
+            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+            init_latent = init_latent.to(devices.cpu)
+
+            self.dataset.append((init_latent, filename_tokens))
+
+        self.length = len(self.dataset) * repeats
+
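+        # Spread the self.length positions evenly over the images, then shuffle
+        # the serving order.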
+        self.initial_indexes = np.arange(self.length) % len(self.dataset)
+        self.indexes = None
+        self.shuffle()
+
+    def shuffle(self):
+        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
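+        # Reshuffle at the start of every full pass over the images.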
+        if i % len(self.dataset) == 0:
+            self.shuffle()
+
+        index = self.indexes[i % len(self.indexes)]
+        x, filename_tokens = self.dataset[index]
+
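+        # Pick a random caption template and fill in its placeholders.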
+        text = random.choice(self.lines)
+        text = text.replace("[name]", self.placeholder_token)
+        text = text.replace("[filewords]", ' '.join(filename_tokens))
+
+        return x, text
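
For orientation, a minimal usage sketch (not part of the commit): it assumes a loaded LDM `model` exposing `encode_first_stage`/`get_first_stage_encoding`, and the paths and argument values shown are illustrative only.

# Hypothetical usage sketch -- names and paths are illustrative, not from this commit.
from modules import devices
from modules.textual_inversion.dataset import PersonalizedBase

dataset = PersonalizedBase(
    data_root="train/my-subject",   # directory of training images (assumed path)
    repeats=100,                    # each image appears 100 times per epoch
    placeholder_token="my-token",   # substituted for [name] in template lines
    width=512,
    height=512,
    model=model,                    # assumed: a loaded model with a first-stage VAE
    device=devices.device,
    template_file="template.txt",   # lines with [name]/[filewords] placeholders
)
latent, caption = dataset[0]        # precomputed VAE latent and a filled-in caption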