aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py23
1 files changed, 19 insertions, 4 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 110c0e09..f470324a 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -138,9 +138,12 @@ class PersonalizedBase(Dataset):
return entry
class PersonalizedDataLoader(DataLoader):
- def __init__(self, *args, **kwargs):
- super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs)
- self.collate_fn = collate_wrapper
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+ super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ if latent_sampling_method == "random":
+ self.collate_fn = collate_wrapper_random
+ else:
+ self.collate_fn = collate_wrapper
class BatchLoader:
@@ -148,10 +151,22 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ #self.emb_index = [entry.emb_index for entry in data]
+ #print(self.latent_sample.device)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
- return BatchLoader(batch) \ No newline at end of file
+ return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+ def __init__(self, data):
+ super().__init__(data)
+
+ def pin_memory(self):
+ return self
+
+def collate_wrapper_random(batch):
+ return BatchLoaderRandom(batch) \ No newline at end of file