from __future__ import absolute_import, division, print_function

import random
import copy
import io
import os
from collections import Counter
from typing import Tuple

import numpy as np
from PIL import Image
import skimage.transform

import torch
import torch.utils.data as data
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

import utils
class ImgDset(Dataset):
    """Custom dataset that prepares paired low/high-resolution image data on the fly.

    Args:
        dataroot (str): Directory containing the training images
        image_size (int or tuple): High-resolution image size (a square crop if an int)
        upscale_factor (int): Image magnification factor
        mode (str): Dataset split; the training split applies data augmentation,
            the validation split does not
    """

    def __init__(self, dataroot: str, image_size, upscale_factor: int, mode: str) -> None:
        super(ImgDset, self).__init__()
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]

        # image_size may be given as an int (square) or a (height, width) pair
        if isinstance(image_size, int):
            lr_size = (image_size // upscale_factor, image_size // upscale_factor)
        else:
            lr_size = (image_size[0] // upscale_factor, image_size[1] // upscale_factor)

        if mode == "train":
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomRotation(90),
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            self.hr_transforms = transforms.Resize(image_size)
        self.lr_transforms = transforms.Resize(lr_size, interpolation=IMode.BICUBIC, antialias=True)

    def __getitem__(self, batch_index: int) -> Tuple[Tensor, Tensor]:
        # Read one image from disk
        image = Image.open(self.filenames[batch_index])
        # Build the high-resolution crop, then downsample it for the low-resolution input
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)
        # Convert image data into PyTorch tensors.
        # Note: the range of input and output is [0, 1]
        lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.filenames)
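
# Minimal usage sketch for ImgDset (illustrative only; assumes a folder of HR training
# images at "./data/train", 128x128 crops, and 4x super-resolution -- adjust to your setup):
#
# train_dset = ImgDset("./data/train", image_size=128, upscale_factor=4, mode="train")
# train_loader = torch.utils.data.DataLoader(train_dset, batch_size=16, shuffle=True, num_workers=4)
# lr, hr = next(iter(train_loader))  # lr is the 32x32 bicubic-downsampled input, hr the 128x128 target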
class PairedImages_w_nameList(Dataset):
    '''
    can act as supervised or un-supervised based on flists
    '''
    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        impath1 = self.flist1[index]
        img1 = Image.open(impath1).convert('RGB')
        impath2 = self.flist2[index]
        img2 = Image.open(impath2).convert('RGB')

        img1 = utils.image2tensor(img1, range_norm=False, half=False)
        img2 = utils.image2tensor(img2, range_norm=False, half=False)

        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)
        return img1, img2

    def __len__(self):
        return len(self.flist1)
class PairedImages_w_nameList_npy(Dataset):
    '''
    can act as supervised or un-supervised based on flists
    '''
    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        impath1 = self.flist1[index]
        img1 = np.load(impath1)
        impath2 = self.flist2[index]
        img2 = np.load(impath2)

        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)
        return img1, img2

    def __len__(self):
        return len(self.flist1)
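
# Minimal usage sketch for the .npy variant (illustrative only; assumes pre-saved numpy
# arrays whose paths line up index-for-index across the two lists):
#
# import glob
# noisy_files = sorted(glob.glob('./npy/noisy/*.npy'))
# clean_files = sorted(glob.glob('./npy/clean/*.npy'))
# dset = PairedImages_w_nameList_npy(noisy_files, clean_files,
#                                    transform1=torch.from_numpy, transform2=torch.from_numpy)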
# Example (kept for reference): build a paired blur/sharp dataset from GOPRO-style folders.
# The lists must be sorted so that blurred and sharp frames stay aligned index-for-index.
# def call_paired():
#     import glob
#     root1 = './GOPRO_3840FPS_AVG_3-21/train/blur/'
#     root2 = './GOPRO_3840FPS_AVG_3-21/train/sharp/'
#     flist1 = sorted(glob.glob(root1 + '/*/*.png'))
#     flist2 = sorted(glob.glob(root2 + '/*/*.png'))
#     dset = PairedImages_w_nameList(flist1, flist2)
#### KITTI depth

def load_velodyne_points(filename):
    """Load 3D point cloud from KITTI file format
    (adapted from https://github.com/hunse/kitti)
    """
    points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
    points[:, 3] = 1.0  # homogeneous coordinates
    return points
def read_calib_file(path):
    """Read KITTI calibration file
    (from https://github.com/hunse/kitti)
    """
    float_chars = set("0123456789.e+- ")
    data = {}
    with open(path, 'r') as f:
        for line in f.readlines():
            key, value = line.split(':', 1)
            value = value.strip()
            data[key] = value
            if float_chars.issuperset(value):
                # try to cast to a float array
                try:
                    data[key] = np.array(list(map(float, value.split(' '))))
                except ValueError:
                    # casting error: data[key] already equals value, so pass
                    pass
    return data
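
# Example of the parsed output (assumes a standard KITTI 'calib_cam_to_cam.txt'; the path
# below is illustrative):
#
# cam2cam = read_calib_file('./kitti_data/2011_09_26/calib_cam_to_cam.txt')
# cam2cam['P_rect_02'].reshape(3, 4)   # 3x4 projection matrix of the left color camera
# cam2cam['S_rect_02']                 # rectified image size as [width, height]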
def sub2ind(matrixSize, rowSub, colSub):
    """Convert row, col matrix subscripts to linear indices
    """
    m, n = matrixSize
    return rowSub * (n - 1) + colSub - 1
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
    """Generate a depth map from velodyne data
    """
    # load calibration files
    cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
    velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
    velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
    velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))

    # get image shape
    im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)

    # compute projection matrix velodyne -> image plane
    R_cam2rect = np.eye(4)
    R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
    P_rect = cam2cam['P_rect_0' + str(cam)].reshape(3, 4)
    P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)

    # load velodyne points and remove all behind image plane (approximation)
    # each row of the velodyne data is forward, left, up, reflectance
    velo = load_velodyne_points(velo_filename)
    velo = velo[velo[:, 0] >= 0, :]

    # project the points to the camera
    velo_pts_im = np.dot(P_velo2im, velo.T).T
    velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]

    if vel_depth:
        velo_pts_im[:, 2] = velo[:, 0]

    # check if in bounds
    # use minus 1 to get the exact same value as the KITTI matlab code
    velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
    velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
    val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
    val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
    velo_pts_im = velo_pts_im[val_inds, :]

    # project to image
    depth = np.zeros((im_shape[:2]))
    # np.int was removed from recent NumPy releases; plain int behaves the same here
    depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2]

    # find the duplicate points and choose the closest depth
    inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
    dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
    for dd in dupe_inds:
        pts = np.where(inds == dd)[0]
        x_loc = int(velo_pts_im[pts[0], 0])
        y_loc = int(velo_pts_im[pts[0], 1])
        depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
    depth[depth < 0] = 0

    return depth
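
# Example call (assumes the standard KITTI raw-data layout; paths are illustrative):
#
# depth = generate_depth_map(
#     calib_dir='./kitti_data/2011_09_26',
#     velo_filename='./kitti_data/2011_09_26/2011_09_26_drive_0002_sync/'
#                   'velodyne_points/data/0000000069.bin',
#     cam=2)
# depth.shape  # (375, 1242): sparse depth in meters, zeros where no lidar return projects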
def pil_loader(path):
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')
class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Args:
        data_path
        filenames
        height
        width
        frame_idxs
        num_scales
        is_train
        img_ext
    """
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        # Image.ANTIALIAS has been removed from recent Pillow releases;
        # the torchvision enum selects the same Lanczos filter.
        self.interp = IMode.LANCZOS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()
    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            frame = inputs[k]
            if "color" in k:
                n, im, i = k
                for i in range(self.num_scales):
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)      for raw colour images,
            ("color_aug", <frame_id>, <scale>)  for augmented colour images,
            ("K", scale) or ("inv_K", scale)    for camera intrinsics,
            "stereo_T"                          for camera extrinsics, and
            "depth_gt"                          for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1      images at native resolution as loaded from disk
             0      images resized to (self.width,      self.height     )
             1      images resized to (self.width // 2, self.height // 2)
             2      images resized to (self.width // 4, self.height // 4)
             3      images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            # ColorJitter.get_params no longer returns a callable in recent torchvision,
            # so build the transform object itself and apply it in preprocess()
            color_aug = transforms.ColorJitter(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError
class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.58, 0, 0.5, 0],
                           [0, 1.92, 0.5, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
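
        # Worked example of the normalization above (approximate KITTI values, for
        # illustration only): with fx ~ 720 px and image_width = 1242 px,
        # fx / image_width ~ 0.58, which is the K[0, 0] entry used here; similarly
        # fy ~ 720 px over image_height = 375 px gives ~ 1.92 for K[1, 1].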
    def check_depth(self):
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)

        return color
class KITTIDepthDataset(KITTIDataset):
    """KITTI dataset which uses the updated ground truth depth maps
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        f_str = "{:010d}{}".format(frame_index, self.img_ext)
        image_path = os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            f_str)
        return image_path

    def get_depth(self, folder, frame_index, side, do_flip):
        f_str = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            f_str)

        depth_gt = Image.open(depth_path)
        depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
        depth_gt = np.array(depth_gt).astype(np.float32) / 256

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt
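
# Minimal usage sketch for the KITTI loaders (illustrative only; assumes KITTI depth data
# under "./kitti_data" and a split file whose lines look like "<folder> <frame_index> <l|r>"):
#
# with open("./splits/eigen_full/train_files.txt") as f:
#     filenames = f.read().splitlines()
# dset = KITTIDepthDataset("./kitti_data", filenames, height=192, width=640,
#                          frame_idxs=[0, -1, 1], num_scales=4, is_train=True, img_ext='.png')
# loader = torch.utils.data.DataLoader(dset, batch_size=4, shuffle=True)
# batch = next(iter(loader))
# batch[("color_aug", 0, 0)].shape   # [4, 3, 192, 640] full-resolution augmented frame
# batch[("K", 1)]                    # intrinsics scaled to the half-resolution level
# batch["depth_gt"].shape            # [4, 1, 375, 1242] ground-truth depth, if available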