aboutsummaryrefslogtreecommitdiff
path: root/ldm/data/base.py
diff options
context:
space:
mode:
authorZac Liu <liuguang@baai.ac.cn>2022-11-30 11:14:04 +0800
committerGitHub <noreply@github.com>2022-11-30 11:14:04 +0800
commita39a57cb1f5964d9af2b541f7b352576adeeac0f (patch)
treeebae98ea40ecc5b34497424bee19310e9fac4068 /ldm/data/base.py
parent4b3c5bc24bffdf429c463a465763b3077fe55eb8 (diff)
parent0831ab476c626eb796b609acf8771177692bfab7 (diff)
Merge pull request #1 from 920232796/master
Add AltDiffusion
Diffstat (limited to 'ldm/data/base.py')
-rw-r--r--ldm/data/base.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/ldm/data/base.py b/ldm/data/base.py
new file mode 100644
index 00000000..b196c2f7
--- /dev/null
+++ b/ldm/data/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass \ No newline at end of file