import logging
from typing import Optional, Union

from pytorch_ie import PieDataModule
from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset
from torch.utils.data import DataLoader, Sampler

from .components.sampler import ImbalancedDatasetSampler

logger = logging.getLogger(__name__)


class PieDataModuleWithSampler(PieDataModule):
    """A PieDataModule that supports an optional sampler for the train dataloader.

    Args:
        train_sampler: name of the sampler to use for training. Currently, only
            "imbalanced_dataset" is supported.
        dont_shuffle_train: if True, do not shuffle the train dataloader even when
            no sampler is used and the dataset is not streamed.
    """

    def __init__(
        self,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.train_sampler_name = train_sampler
        self.dont_shuffle_train = dont_shuffle_train

    def get_train_sampler(
        self,
        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
    ) -> Optional[Sampler]:
        if self.train_sampler_name is None:
            return None
        elif self.train_sampler_name == "imbalanced_dataset":
            # for now, this works only with targets that have a single entry
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        else:
            raise ValueError(f"unknown sampler name: {self.train_sampler_name}")

    def train_dataloader(self) -> DataLoader:
        ds = self.data_split(self.train_split)
        sampler = self.get_train_sampler(dataset=ds)
        # don't shuffle if dont_shuffle_train is set explicitly, if the dataset
        # is streamed (iterable), or if a sampler is used
        shuffle = not (
            self.dont_shuffle_train
            or isinstance(ds, IterableTaskEncodingDataset)
            or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )
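
# Usage sketch (illustrative only): selecting the imbalanced-dataset sampler via
# the `train_sampler` argument. All other constructor arguments are forwarded to
# PieDataModule, so the `taskmodule`/`dataset` names below are assumptions about
# that API, and `batch_size` ends up in `dataloader_kwargs`.
#
#     dm = PieDataModuleWithSampler(
#         taskmodule=taskmodule,
#         dataset=dataset,
#         train_sampler="imbalanced_dataset",  # oversample rare target labels
#         batch_size=32,
#     )
#     dm.setup(stage="fit")
#     loader = dm.train_dataloader()  # shuffle is disabled since a sampler is used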