import logging
from typing import Optional, Union

from pytorch_ie import PieDataModule
from pytorch_ie.core.taskmodule import IterableTaskEncodingDataset, TaskEncodingDataset
from torch.utils.data import DataLoader, Sampler

from .components.sampler import ImbalancedDatasetSampler

logger = logging.getLogger(__name__)

class PieDataModuleWithSampler(PieDataModule):
    """A PieDataModule that supports a custom sampler for the train dataloader.

    Args:
        train_sampler: name of the sampler to use for training; currently only
            "imbalanced_dataset" is supported.
        dont_shuffle_train: if True, never shuffle the train dataloader.
    """

    def __init__(
        self,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.train_sampler_name = train_sampler
        self.dont_shuffle_train = dont_shuffle_train

    def get_train_sampler(
        self,
        dataset: Union[TaskEncodingDataset, IterableTaskEncodingDataset],
    ) -> Optional[Sampler]:
        if self.train_sampler_name is None:
            return None
        elif self.train_sampler_name == "imbalanced_dataset":
            # for now, this works only with targets that have a single entry:
            # the label of each task encoding is taken to be its first target
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        else:
            raise ValueError(f"unknown sampler name: {self.train_sampler_name}")

    def train_dataloader(self) -> DataLoader:
        ds = self.data_split(self.train_split)
        sampler = self.get_train_sampler(dataset=ds)
        # don't shuffle if dont_shuffle_train is set explicitly, if the dataset is
        # streamed (iterable datasets don't support shuffling), or if a sampler is
        # used (DataLoader does not allow combining a sampler with shuffle=True)
        shuffle = not (
            self.dont_shuffle_train
            or isinstance(ds, IterableTaskEncodingDataset)
            or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )
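
# Usage sketch (hypothetical, for illustration only): `my_taskmodule` and
# `my_dataset` stand in for a prepared taskmodule and dataset; the exact
# PieDataModule constructor arguments depend on the installed pytorch-ie
# version, and any extra keyword arguments are forwarded to the DataLoader.
#
#   datamodule = PieDataModuleWithSampler(
#       taskmodule=my_taskmodule,
#       dataset=my_dataset,
#       train_sampler="imbalanced_dataset",  # oversample rare target labels
#       batch_size=32,
#   )
#   datamodule.setup(stage="fit")
#   loader = datamodule.train_dataloader()  # sampler is used, shuffle disabled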