Spaces:
Runtime error
Runtime error
| import cv2 | |
| import torch | |
| from PIL import Image | |
| import os.path as osp | |
| import numpy as np | |
| from torch.utils import data | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| import torchvision.transforms.functional as TF | |
| from .custom_transform import * | |
| class _Coco164kCuratedFew(data.Dataset): | |
| """Base class | |
| This contains fields and methods common to all COCO 164k curated few datasets: | |
| (curated) Coco164kFew_Stuff | |
| (curated) Coco164kFew_Stuff_People | |
| (curated) Coco164kFew_Stuff_Animals | |
| (curated) Coco164kFew_Stuff_People_Animals | |
| """ | |
| def __init__(self, root, img_size, crop_size, split = "train2017"): | |
| super(_Coco164kCuratedFew, self).__init__() | |
| # work out name | |
| self.split = split | |
| self.root = root | |
| self.include_things_labels = False # people | |
| self.incl_animal_things = False # animals | |
| version = 6 | |
| name = "Coco164kFew_Stuff" | |
| if self.include_things_labels and self.incl_animal_things: | |
| name += "_People_Animals" | |
| elif self.include_things_labels: | |
| name += "_People" | |
| elif self.incl_animal_things: | |
| name += "_Animals" | |
| self.name = (name + "_%d" % version) | |
| print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name) | |
| self._set_files() | |
| self.transform = transforms.Compose([ | |
| transforms.Resize(int(img_size)), | |
| transforms.RandomCrop(crop_size)]) | |
| N = len(self.files) | |
| # eqv transform | |
| self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N) | |
| self.random_vertical_flip = RandomVerticalFlip(N=N) | |
| self.random_resized_crop = RandomResizedCrop(N=N, res=288) | |
| # photometric transform | |
| self.random_color_brightness = [RandomColorBrightness(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE)] | |
| self.random_color_contrast = [RandomColorContrast(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) | |
| self.random_color_saturation = [RandomColorSaturation(x=0.3, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) | |
| self.random_color_hue = [RandomColorHue(x=0.1, p=0.8, N=N) for _ in range(2)] # Control this later (NOTE) | |
| self.random_gray_scale = [RandomGrayScale(p=0.2, N=N) for _ in range(2)] | |
| self.random_gaussian_blur = [RandomGaussianBlur(sigma=[.1, 2.], p=0.5, N=N) for _ in range(2)] | |
| self.eqv_list = ['random_crop', 'h_flip'] | |
| self.inv_list = ['brightness', 'contrast', 'saturation', 'hue', 'gray', 'blur'] | |
| self.transform_tensor = TensorTransform() | |
| def _set_files(self): | |
| # Create data list by parsing the "images" folder | |
| if self.split in ["train2017", "val2017"]: | |
| file_list = osp.join(self.root, "curated", self.split, self.name + ".txt") | |
| file_list = tuple(open(file_list, "r")) | |
| file_list = [id_.rstrip() for id_ in file_list] | |
| self.files = file_list | |
| print("In total {} images.".format(len(self.files))) | |
| else: | |
| raise ValueError("Invalid split name: {}".format(self.split)) | |
| def transform_eqv(self, indice, image): | |
| if 'random_crop' in self.eqv_list: | |
| image = self.random_resized_crop(indice, image) | |
| if 'h_flip' in self.eqv_list: | |
| image = self.random_horizontal_flip(indice, image) | |
| if 'v_flip' in self.eqv_list: | |
| image = self.random_vertical_flip(indice, image) | |
| return image | |
| def transform_inv(self, index, image, ver): | |
| """ | |
| Hyperparameters same as MoCo v2. | |
| (https://github.com/facebookresearch/moco/blob/master/main_moco.py) | |
| """ | |
| if 'brightness' in self.inv_list: | |
| image = self.random_color_brightness[ver](index, image) | |
| if 'contrast' in self.inv_list: | |
| image = self.random_color_contrast[ver](index, image) | |
| if 'saturation' in self.inv_list: | |
| image = self.random_color_saturation[ver](index, image) | |
| if 'hue' in self.inv_list: | |
| image = self.random_color_hue[ver](index, image) | |
| if 'gray' in self.inv_list: | |
| image = self.random_gray_scale[ver](index, image) | |
| if 'blur' in self.inv_list: | |
| image = self.random_gaussian_blur[ver](index, image) | |
| return image | |
| def transform_image(self, index, image): | |
| image1 = self.transform_inv(index, image, 0) | |
| image1 = self.transform_tensor(image) | |
| image2 = self.transform_inv(index, image, 1) | |
| #image2 = TF.resize(image2, self.crop_size, Image.BILINEAR) | |
| image2 = self.transform_tensor(image2) | |
| return image1, image2 | |
| def __getitem__(self, index): | |
| # same as _Coco164k | |
| # Set paths | |
| image_id = self.files[index] | |
| image_path = osp.join(self.root, "images", self.split, image_id + ".jpg") | |
| # Load an image | |
| ori_img = Image.open(image_path) | |
| ori_img = self.transform(ori_img) | |
| image1, image2 = self.transform_image(index, ori_img) | |
| if image1.shape[0] < 3: | |
| image1 = image1.repeat(3, 1, 1) | |
| if image2.shape[0] < 3: | |
| image2 = image2.repeat(3, 1, 1) | |
| rets = [] | |
| rets.append(image1) | |
| rets.append(image2) | |
| rets.append(index) | |
| return rets | |
| def __len__(self): | |
| return len(self.files) | |