Spaces:
Runtime error
Runtime error
| import random | |
| import numpy as np | |
| import os.path | |
| from swapae.data.base_dataset import BaseDataset, get_transform | |
| import torchvision | |
class CIFAR100Dataset(BaseDataset):
    """CIFAR-100 dataset that serves pairs of images sharing a class label.

    Wraps ``torchvision.datasets.CIFAR100`` (downloaded into ``opt.dataroot``
    on demand).  Each item contains the requested image (``real_A``) and a
    randomly chosen image with the same label (``real_B``), both passed
    through the transform returned by ``get_transform``.
    """

    # Label ids used as keys of the class -> sample-indices mapping.
    _NUM_CLASSES = 100

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Install the CIFAR-100 option defaults on ``parser``.

        Args:
            parser: argparse-style parser providing ``set_defaults`` and
                ``parse_known_args``.
            is_train: unused here; kept for interface compatibility.

        Returns:
            The same ``parser`` object, with defaults updated.

        Raises:
            AssertionError: if the effective options are not
                ``preprocess='crop'`` with ``load_size == crop_size == 32``.
        """
        parser.set_defaults(load_size=32, crop_size=32, preprocess_crop_padding=0,
                            preprocess='crop', num_classes=100, use_class_labels=True)
        # Defaults can still be overridden on the command line, so re-parse
        # and check the values that will actually be used.
        opt, _ = parser.parse_known_args()
        assert opt.preprocess == 'crop' and opt.load_size == 32 and opt.crop_size == 32
        return parser

    def __init__(self, opt):
        """Load (downloading if needed) CIFAR-100 and index samples by class.

        Args:
            opt: options object; this class reads ``dataroot``, ``isTrain``
                and ``phase`` from it, and passes it to ``get_transform``.
        """
        self.opt = opt
        self.torch_dataset = torchvision.datasets.CIFAR100(
            opt.dataroot, train=opt.isTrain, download=True
        )
        self.transform = get_transform(self.opt, grayscale=False)
        self.class_list = self.create_class_list()

    def create_class_list(self):
        """Return a dict mapping each class id to the indices of its samples.

        The per-sample labels are cached as
        ``<dataroot>/<phase>_classlist.npy``.  A cache whose shape does not
        match the current dataset is treated as stale and rebuilt, since its
        indices could not refer to this dataset's samples.

        Returns:
            dict[int, list[int]]: keys ``0 .. 99``; each value lists the
            dataset indices carrying that label, in ascending order.
        """
        cache_path = os.path.join(self.opt.dataroot, "%s_classlist.npy" % self.opt.phase)
        num_samples = len(self.torch_dataset)

        labels = None
        if os.path.exists(cache_path):
            labels = np.load(cache_path)
            if labels.shape != (num_samples,):
                print("cache at %s does not match the dataset; rebuilding..." % cache_path)
                labels = None

        if labels is None:
            print("creating cache list of classes...")
            labels = np.zeros(num_samples, dtype=int)
            for i in range(num_samples):
                _, class_id = self.torch_dataset[i]
                labels[i] = class_id
                if i % 100 == 0:
                    print("%d/%d\r" % (i, num_samples), end="", flush=True)
            np.save(cache_path, labels)
            print("cache saved at %s" % cache_path)

        # Build the mapping straight from the in-memory labels instead of
        # recursing to re-read the file that was just written.
        classlist = {i: [] for i in range(self._NUM_CLASSES)}
        for sample_index, class_id in enumerate(labels):
            classlist[int(class_id)].append(sample_index)
        return classlist

    def __getitem__(self, index):
        """Return the sample at ``index`` paired with a same-class sample.

        Args:
            index: sample index; wrapped modulo the dataset length.

        Returns:
            dict with ``real_A`` and ``real_B`` (the two transformed images)
            and ``class_A`` / ``class_B`` (both equal to the shared label).
        """
        index = index % len(self.torch_dataset)
        image, class_id = self.torch_dataset[index]

        # The partner is drawn from all samples of the class, so it may be
        # the very same image.
        another_image_index = random.choice(self.class_list[class_id])
        another_image, another_class_id = self.torch_dataset[another_image_index]
        assert class_id == another_class_id

        return {"real_A": self.transform(image),
                "real_B": self.transform(another_image),
                "class_A": class_id, "class_B": class_id}

    def __len__(self):
        """Return the number of samples in the wrapped torchvision dataset."""
        return len(self.torch_dataset)