aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-13 14:57:38 +0300
committerGitHub <noreply@github.com>2023-01-13 14:57:38 +0300
commit9cd7716753c5be47f76b8e5555cc3e7c0f17d34d (patch)
tree345be78dd1991b77fcf4519bc44097e975e0b0c4 /modules/textual_inversion/dataset.py
parent18f86e41f6f289042c075bff1498e620ab997b8c (diff)
parent544e7a233e994f379dd67df08f5f519290b10293 (diff)
Merge branch 'master' into tensorboard
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py183
1 files changed, 139 insertions, 44 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..d31963d4 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,35 +3,38 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices
import random
import tqdm
from modules import devices, shared
import re
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, latent=None, filename_text=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename
- self.latent = latent
self.filename_text = filename_text
- self.cond = None
- self.cond_text = None
+ self.latent_dist = latent_dist
+ self.latent_sample = latent_sample
+ self.cond = cond
+ self.cond_text = cond_text
+ self.pixel_values = pixel_values
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.batch_size = batch_size
- self.width = width
- self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
@@ -42,14 +45,23 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
-
- cond_model = shared.sd_model.cond_stage_model
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+
+ self.shuffle_tags = shuffle_tags
+ self.tag_drop_out = tag_drop_out
+ groups = defaultdict(list)
+
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
+ if shared.state.interrupted:
+ raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB')
+ if not varsize:
+ image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
continue
@@ -69,53 +81,136 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
- torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
- torchdata = torch.moveaxis(torchdata, 2, 0)
-
- init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
- init_latent = init_latent.to(devices.cpu)
-
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
-
- if include_cond:
+ torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
+ latent_sample = None
+
+ with devices.autocast():
+ latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
+
+ if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+ latent_sampling_method = "once"
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ elif latent_sampling_method == "deterministic":
+ # Works only for DiagonalGaussianDistribution
+ latent_dist.std = 0
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
+ elif latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+
+ if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+ if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
+ with devices.autocast():
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+ groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
-
- assert len(self.dataset) > 1, "No images have been found in the dataset."
- self.length = len(self.dataset) * repeats // batch_size
-
- self.initial_indexes = np.arange(len(self.dataset))
- self.indexes = None
- self.shuffle()
-
- def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ del torchdata
+ del latent_dist
+ del latent_sample
+
+ self.length = len(self.dataset)
+ self.groups = list(groups.values())
+ assert self.length > 0, "No images have been found in the dataset."
+ self.batch_size = min(batch_size, self.length)
+ self.gradient_step = min(gradient_step, self.length // self.batch_size)
+ self.latent_sampling_method = latent_sampling_method
+
+ if len(groups) > 1:
+ print("Buckets:")
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+ print(f" {w}x{h}: {len(ids)}")
+ print()
def create_text(self, filename_text):
text = random.choice(self.lines)
+ tags = filename_text.split(',')
+ if self.tag_drop_out != 0:
+ tags = [t for t in tags if random.random() > self.tag_drop_out]
+ if self.shuffle_tags:
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
return self.length
def __getitem__(self, i):
- res = []
+ entry = self.dataset[i]
+ if self.tag_drop_out != 0 or self.shuffle_tags:
+ entry.cond_text = self.create_text(entry.filename_text)
+ if self.latent_sampling_method == "random":
+ entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
+ return entry
+
+
+class GroupedBatchSampler(Sampler):
+ def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ super().__init__(data_source)
+
+ n = len(data_source)
+ self.groups = data_source.groups
+ self.len = n_batch = n // batch_size
+ expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+ self.base = [int(e) // batch_size for e in expected]
+ self.n_rand_batches = nrb = n_batch - sum(self.base)
+ self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
+ self.batch_size = batch_size
+
+ def __len__(self):
+ return self.len
+
+ def __iter__(self):
+ b = self.batch_size
+
+ for g in self.groups:
+ shuffle(g)
+
+ batches = []
+ for g in self.groups:
+ batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+ for _ in range(self.n_rand_batches):
+ rand_group = choices(self.groups, self.probs)[0]
+ batches.append(choices(rand_group, k=b))
+
+ shuffle(batches)
+
+ yield from batches
+
+
+class PersonalizedDataLoader(DataLoader):
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
+ if latent_sampling_method == "random":
+ self.collate_fn = collate_wrapper_random
+ else:
+ self.collate_fn = collate_wrapper
+
+
+class BatchLoader:
+ def __init__(self, data):
+ self.cond_text = [entry.cond_text for entry in data]
+ self.cond = [entry.cond for entry in data]
+ self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ #self.emb_index = [entry.emb_index for entry in data]
+ #print(self.latent_sample.device)
- for j in range(self.batch_size):
- position = i * self.batch_size + j
- if position % len(self.indexes) == 0:
- self.shuffle()
+ def pin_memory(self):
+ self.latent_sample = self.latent_sample.pin_memory()
+ return self
- index = self.indexes[position % len(self.indexes)]
- entry = self.dataset[index]
+def collate_wrapper(batch):
+ return BatchLoader(batch)
- if entry.cond is None:
- entry.cond_text = self.create_text(entry.filename_text)
+class BatchLoaderRandom(BatchLoader):
+ def __init__(self, data):
+ super().__init__(data)
- res.append(entry)
+ def pin_memory(self):
+ return self
- return res
+def collate_wrapper_random(batch):
+ return BatchLoaderRandom(batch) \ No newline at end of file