aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordan <guaneec@gmail.com>2023-01-11 05:31:58 +0800
committerdan <guaneec@gmail.com>2023-01-11 05:31:58 +0800
commit6be644fa04ce1542f3a01804310cbbc0a4a91620 (patch)
tree98791617ab751d77aea24e3670dcb01eecc28fa7
parent50fb20cedc8dcbf64f86aed6d6e89595d655e638 (diff)
Enable batch_size>1 for mixed-sized training
-rw-r--r--modules/textual_inversion/dataset.py36
1 files changed, 32 insertions, 4 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index fa48708e..b47414f3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -3,8 +3,10 @@ import numpy as np
import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices
import random
import tqdm
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
+ groups = defaultdict(list)
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
@@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
-
+ groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
self.length = len(self.dataset)
+ self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
@@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
+class GroupedBatchSampler(Sampler):
+ def __init__(self, data_source: PersonalizedBase, batch_size: int):
+ n = len(data_source)
+ self.groups = data_source.groups
+ self.len = n_batch = n // batch_size
+ expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
+ self.base = [int(e) // batch_size for e in expected]
+ self.n_rand_batches = nrb = n_batch - sum(self.base)
+ self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
+ self.batch_size = batch_size
+ def __len__(self):
+ return self.len
+ def __iter__(self):
+ b = self.batch_size
+ for g in self.groups:
+ shuffle(g)
+ batches = []
+ for g in self.groups:
+ batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
+ for _ in range(self.n_rand_batches):
+ rand_group = choices(self.groups, self.probs)[0]
+ batches.append(choices(rand_group, k=b))
+ shuffle(batches)
+ yield from batches
+
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else: