| import os | |
| import torch | |
| import torch.utils.data as data | |
| import numpy as np | |
| from PIL import Image | |
| import h5py | |
# Names exported by ``from <this module> import *``. Imagenet_Segmentation is
# not listed here, so callers must import it explicitly by name.
__all__ = ['ImagenetResults']
class Imagenet_Segmentation(data.Dataset):
    """Segmentation dataset backed by a single HDF5 file.

    Images and ground-truth masks are stored as HDF5 object references under
    ``/value/img`` and ``/value/gt``. The mask entry is a reference to a
    container that itself holds a reference, hence the double dereference in
    ``__getitem__``. Stored arrays are read with their axis order reversed.
    NOTE(review): this layout looks like MATLAB v7.3 (column-major) output --
    confirm against the data producer.
    """

    # Number of segmentation classes (hard-coded; presumably
    # foreground/background -- confirm against consumers).
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        """
        Args:
            path: Path to the HDF5 file.
            transform: Optional callable applied to the PIL RGB image.
            target_transform: Optional callable applied to the PIL mask; its
                result is converted to an int64 tensor.
        """
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        # The persistent handle is opened lazily on first item access
        # (presumably so each DataLoader worker opens its own handle).
        self.h5py = None
        # Only the length is needed here; the context manager guarantees the
        # file is closed even if the expected key is missing.
        with h5py.File(path, 'r') as f:
            self.data_length = len(f['/value/img'])

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``.

        ``img`` is a PIL RGB image, or ``transform(img)`` when a transform was
        given. ``target`` is a PIL image of the mask, or -- when
        ``target_transform`` is set -- an int64 tensor built from
        ``target_transform(target)``.
        """
        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')
        f = self.h5py

        # Image: one level of reference indirection, then reverse the 3 axes.
        img_ref = f['/value/img'][index, 0]
        img = np.array(f[img_ref]).transpose((2, 1, 0))

        # Mask: the entry points to a container whose [0, 0] element is the
        # reference to the actual mask array; reverse its 2 axes.
        gt_ref = f[f['/value/gt'][index, 0]][0, 0]
        target = np.array(f[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            # torch.from_numpy needs an ndarray, so the tensor conversion only
            # applies to the transformed mask, not to the raw PIL image.
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        """Number of samples, read once at construction time."""
        return self.data_length
class ImagenetResults(data.Dataset):
    """Read-only dataset over ``<path>/results.hdf5``.

    Item ``i`` is the triple ``(image, vis, target)`` taken from the ``image``,
    ``vis`` and ``target`` datasets at index ``i``, each converted with
    ``torch.tensor``; ``target`` is additionally cast to int64.
    """

    def __init__(self, path):
        super().__init__()
        self.path = os.path.join(path, 'results.hdf5')
        # Persistent HDF5 handle; stays None until the first item is read.
        self.data = None
        print('Reading dataset length...')
        # Open briefly just to count rows; the file is closed on exit.
        with h5py.File(self.path, 'r') as results:
            self.data_length = len(results['/image'])

    def __len__(self):
        """Row count captured when the dataset was constructed."""
        return self.data_length

    def __getitem__(self, item):
        """Return ``(image, vis, target)`` tensors for row ``item``."""
        if self.data is None:
            self.data = h5py.File(self.path, 'r')
        store = self.data
        # Datasets are read in the order image, vis, target.
        image, vis, target = (
            torch.tensor(store[key][item]) for key in ('image', 'vis', 'target')
        )
        return image, vis, target.long()