diff options
Diffstat (limited to 'ldm/data/base.py')
-rw-r--r-- | ldm/data/base.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/ldm/data/base.py b/ldm/data/base.py new file mode 100644 index 00000000..b196c2f7 --- /dev/null +++ b/ldm/data/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass
\ No newline at end of file |