import os

import numpy as np
import torch
import torchvision.datasets as datasets


class ImageFolderWithFilename(datasets.ImageFolder):
    """``ImageFolder`` variant whose items also carry a short relative filename.

    Each item is ``(sample, target, filename)``. ``filename`` is the last two
    components of the sample's path (for the usual ImageFolder layout this is
    ``<class_dir>/<file>``), re-joined with the platform path separator.
    """

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, filename).
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        # Keep only "<parent_dir>/<file>" so the name does not depend on root.
        filename = os.path.join(*path.split(os.path.sep)[-2:])
        return sample, target, filename


class CachedFolder(datasets.DatasetFolder):
    """Dataset over a folder tree of pre-computed ``.npz`` files.

    Each ``.npz`` file must contain the arrays ``'moments'`` and
    ``'moments_flip'``. On every access one of the two is chosen with
    probability 0.5. Presumably ``'moments_flip'`` was computed from the
    horizontally flipped image, making this a cached hflip augmentation —
    confirm against the caching script.
    """

    def __init__(
        self,
        root: str,
    ):
        """
        Args:
            root (str): Root directory; samples are discovered as ``.npz``
                files under its class subdirectories.
        """
        # No loader is needed: __getitem__ reads the .npz files itself.
        super().__init__(
            root,
            loader=None,
            extensions=(".npz",),
        )

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (moments, target).
        """
        path, target = self.samples[index]
        # np.load on an .npz returns an NpzFile that keeps the file open; use it
        # as a context manager so the handle is released now rather than at GC.
        # Indexing the NpzFile reads the whole array into memory, so `moments`
        # remains valid after the file is closed.
        with np.load(path) as data:
            if torch.rand(1) < 0.5:  # randomly hflip
                moments = data['moments']
            else:
                moments = data['moments_flip']
        return moments, target