| import json | |
| from torchvision.datasets import ImageFolder | |
| import torch | |
| import os | |
| from PIL import Image | |
| import collections | |
| import torchvision.transforms as transforms | |
| from label_str_to_imagenet_classes import label_str_to_imagenet_classes | |
# Seed torch's global RNG at import time.
torch.manual_seed(0)

# Record for one image on disk: its file name and the folder ("tag") containing it.
ImageItem = collections.namedtuple('ImageItem', 'image_name tag')

# Per-channel normalization with mean 0.5 and std 0.5 on all three channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing pipeline applied to every image: resize, center-crop to 224,
# convert to a tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
class RobustnessDataset(ImageFolder):
    """Dataset over the images stored in a single class folder.

    Every file found in ``<imagenet_path>/<folder>`` becomes one sample. All
    samples share the same label, which is derived from the folder name:

    * ``isV2=True``  -- the folder name itself is parsed as the integer label.
    * ``isSI=True``  -- the folder name is looked up in
      ``label_str_to_imagenet_classes``.
    * otherwise      -- the folder name is looked up in the JSON mapping loaded
      from ``imagenet_classes_path``.

    ``isV2`` takes precedence over ``isSI`` when both are set.

    NOTE(review): ``ImageFolder.__init__`` is not called here; the class builds
    its own file list instead. Attributes normally set by ``ImageFolder`` are
    therefore not available on instances.
    """

    def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        """Index the files of ``folder`` under ``imagenet_path``.

        Args:
            imagenet_path: Root directory that contains the class folders.
            folder: Name of the class folder to load; also used as the tag
                from which the label is derived.
            imagenet_classes_path: Path to a JSON file mapping folder names to
                class indices. It is always loaded, even when ``isV2`` or
                ``isSI`` is set.
            isV2: Treat the folder name as the integer label.
            isSI: Resolve the label through ``label_str_to_imagenet_classes``.

        Raises:
            OSError: If the JSON file or the folder cannot be read.
        """
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        base_dir = os.path.join(self._imagenet_path, folder)
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file across runs and machines.
        self._all_images = [
            ImageItem(file_name, folder) for file_name in sorted(os.listdir(base_dir))
        ]

    def _label_for(self, tag):
        """Return the integer class label for the folder name ``tag``."""
        if self._isV2:
            return int(tag)
        if self._isSI:
            return int(label_str_to_imagenet_classes[tag])
        return int(self._imagenet_classes[tag])

    def __getitem__(self, item):
        """Return ``(image, label)`` for the sample at index ``item``.

        The image is loaded from disk, converted to RGB and passed through the
        module-level ``transform`` pipeline.
        """
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        # The context manager closes the underlying file even if decoding
        # fails; convert() returns an independent in-memory copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = transform(image)
        return image, self._label_for(image_item.tag)

    def __len__(self):
        """Return the number of indexed files."""
        return len(self._all_images)

    def get_classname(self):
        """Return the integer class label shared by every sample in this dataset."""
        return self._label_for(self._folder)