import os
import random
import time
import xml.etree.ElementTree as ET
from collections import Counter, defaultdict
from pathlib import Path
from typing import Optional
from dotdict import dotdict
import yaml

import cv2
import numpy as np
import omegaconf
import torch
import torch.nn.functional as F
from scipy.spatial.transform import Rotation
from torch.utils.data import Dataset
from torchvision.transforms import ColorJitter

from scenedino.common.geometry import estimate_frustum_overlap_2
from scenedino.common.point_sampling import regular_grid
from scenedino.datasets.old_kitti_360 import FisheyeToPinholeSampler, OldKITTI360Dataset
from datasets.kitti_360.annotation import KITTI360Bbox3D

from scenedino.common.augmentation import get_color_aug_fn


def three_to_four(matrix):
    """Embed a (possibly batched) 3x3 matrix into the top-left of a 4x4 identity.

    Args:
        matrix (torch.Tensor): Tensor of shape (..., 3, 3).

    Returns:
        torch.Tensor: Tensor of shape (..., 4, 4) with the dtype and device of
        ``matrix``; the last row and column are taken from the identity.
    """
    batch_shape = matrix.shape[:-2]
    new_matrix = torch.eye(4, dtype=matrix.dtype, device=matrix.device)
    # ``repeat`` materializes one independent 4x4 per batch element, so batched
    # inputs of any size can be written into it (a singleton ``view`` cannot).
    new_matrix = new_matrix.repeat(*batch_shape, 1, 1)
    new_matrix[..., :3, :3] = matrix
    return new_matrix


class FrameSamplingStrategy:
    """Strategy that determines how frames should be sampled around a given base frame."""

    def sample(self, index, nbr_samples, *args, **kwargs):
        """Return the frames to load for the base frame ``index``.

        Subclasses may require additional arguments (e.g. poses and calibrations).
        """
        raise NotImplementedError


class OverlapFrameSamplingStrategy(FrameSamplingStrategy):
    """Sample frames whose view frustums overlap with a randomly chosen encoder frame.

    The encoder frame is frame ``index`` seen from camera "00" or "01" (chosen
    with equal probability). Every frame is added together with its immediate
    temporal neighbour (+1 or -1) from the same camera. Additional frame pairs
    are drawn from the camera/offset ranges configured for the encoder camera
    and are rejected when their estimated frustum overlap with the encoder frame
    is below ``min_ratio`` -- unless rejecting would leave too few draws to
    still reach ``n_frames``.
    """

    def __init__(self, n_frames, cams=("00", "01"), **kwargs) -> None:
        """Configure the sampler.

        Args:
            n_frames (int): Target number of frames. Sampling stops once at
                least this many are collected; since frames are added in pairs
                (starting with two), the result may exceed it by one.
            cams (tuple[str, ...], optional): Stored as ``self.cams``; not read
                by :meth:`sample`. Defaults to ("00", "01").
            **kwargs: Optional overrides:
                max_samples (int): Maximum number of candidate draws. Default 128.
                min_ratio (float): Minimum estimated frustum overlap with the
                    encoder frame for a candidate to be kept. Default 0.4.
                max_steps (int): Stored as ``self.max_steps``; not read by
                    :meth:`sample`. Default 5.
                ranges_00 / ranges_01 (dict[str, tuple[int, int]]): For an
                    encoder on camera "00" / "01", maps each candidate camera to
                    a half-open ``[low, high)`` range of frame offsets relative
                    to the encoder frame.
        """
        super().__init__()
        self.n_frames = n_frames
        self.cams = cams
        self.max_samples = kwargs.get("max_samples", 128)
        self.min_ratio = kwargs.get("min_ratio", .4)
        self.max_steps = kwargs.get("max_steps", 5)

        self.ranges_00 = kwargs.get("ranges_00", {
            "00": (-10, 20),
            # "01": (10, 45),
            "02": (10, 50),
            "03": (10, 50),
        })

        # Build the default lazily: it reads keys of ``ranges_00`` that a
        # user-supplied ``ranges_00`` need not contain.
        if "ranges_01" in kwargs:
            self.ranges_01 = kwargs["ranges_01"]
        else:
            self.ranges_01 = {
                # "00": self.ranges_00["01"],
                "01": self.ranges_00["00"],
                "02": self.ranges_00["02"],
                "03": self.ranges_00["03"],
            }

        self.ranges = {
            "00": self.ranges_00,
            "01": self.ranges_01,
        }

    @staticmethod
    def _clamp(frame_id, nbr_samples):
        """Clamp ``frame_id`` into ``[0, nbr_samples - 1]``."""
        return max(min(frame_id, nbr_samples - 1), 0)

    @staticmethod
    def _projection(calibs, cam):
        """Intrinsics of ``cam`` as a float32 tensor with a leading batch dim of 1."""
        key = "K_perspective" if cam in ("00", "01") else "K_fisheye"
        return torch.tensor(calibs[key], dtype=torch.float32)[None, :, :]

    def sample(self, index, nbr_samples, poses, calibs):
        """Draw frames for the encoder frame at ``index``.

        Args:
            index (int): Index of the encoder frame within the sequence.
            nbr_samples (int): Number of frames in the sequence; every returned
                index is clamped to ``[0, nbr_samples - 1]``.
            poses: Per-frame poses of the sequence, convertible with
                ``torch.tensor``; ``poses[i]`` is right-multiplied by
                ``calibs["T_cam_to_pose"][cam]``.
            calibs (dict): Calibration with "K_perspective", "K_fisheye" and
                "T_cam_to_pose" entries.

        Returns:
            list[tuple[str, int]]: ``(camera, frame_index)`` pairs. The first
            entry is the encoder frame, the second its temporal neighbour.
        """
        poses = torch.tensor(poses, dtype=torch.float32)
        cam_to_pose = calibs["T_cam_to_pose"]

        # Encoder camera: "00" or "01" with equal probability.
        base_cam = "00" if random.random() < .5 else "01"
        all_ranges = self.ranges[base_cam]
        candidate_cams = list(all_ranges)

        encoder_proj = self._projection(calibs, base_cam)
        encoder_pose = poses[index:index + 1, :, :] @ cam_to_pose[base_cam]

        step = 1 if random.random() > .5 else -1
        ids = [(base_cam, index), (base_cam, index + step)]

        for i in range(self.max_samples):
            if len(ids) >= self.n_frames:
                break

            cam = random.choice(candidate_cams)
            index_offset = random.randrange(all_ranges[cam][0], all_ranges[cam][1])
            step = 1 if random.random() >= .5 else -1

            frame_id = self._clamp(index + index_offset, nbr_samples)
            neighbour_id = self._clamp(frame_id + step, nbr_samples)

            proj = self._projection(calibs, cam)
            extr = poses[frame_id:frame_id + 1, :, :] @ cam_to_pose[cam]
            overlap = estimate_frustum_overlap_2(proj, extr, encoder_proj, encoder_pose).item()

            # Reject low-overlap candidates only while enough draws remain to
            # still reach ``n_frames`` (each accepted draw adds two frames).
            remaining_draws = self.max_samples - i
            if overlap < self.min_ratio and remaining_draws * 2 > (self.n_frames - len(ids)):
                continue

            ids += [(cam, frame_id), (cam, neighbour_id)]

        return [(cam, self._clamp(frame_id, nbr_samples)) for cam, frame_id in ids]


class KITTI360DatasetV2(OldKITTI360Dataset):
    """KITTI-360 dataset whose frames are chosen by :class:`OverlapFrameSamplingStrategy`."""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Perspective cameras are not resampled. The fisheye resamplers are not
        # defined in this class -- presumably set up by OldKITTI360Dataset.
        self._resamplers = {
            "00": None,
            "01": None,
            "02": self._resampler_02,
            "03": self._resampler_03,
        }

        self.frame_sampling_strategy = OverlapFrameSamplingStrategy(n_frames=self.frame_count)

    def load_images(self, seq, img_ids):
        """Load RGB images for ``(cam, frame_id, img_id)`` triples.

        Args:
            seq (str): Sequence folder name.
            img_ids (list[tuple[str, int, int]]): Camera, frame index, and the
                raw image number used in the file name.

        Returns:
            list[np.ndarray]: float32 RGB images divided by 255.

        Raises:
            FileNotFoundError: If an image file cannot be read.
        """
        imgs = []
        for cam, _, img_id in img_ids:
            folder = self._perspective_folder if cam in ("00", "01") else self._fisheye_folder
            path = os.path.join(
                self.data_path,
                "data_2d_raw",
                seq,
                f"image_{cam}",
                folder,
                f"{img_id:010d}.png",
            )
            bgr = cv2.imread(path)
            if bgr is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {path}")
            imgs.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255)
        return imgs

    def process_img(
        self,
        img: np.ndarray,
        color_aug_fn=None,
        resampler: Optional[FisheyeToPinholeSampler] = None,
    ):
        """Convert an HxWxC image to a CxHxW tensor mapped with ``x * 2 - 1``.

        Fisheye images are resampled with ``resampler`` unless the dataset is
        preprocessed; otherwise the image is resized to ``target_image_size``
        (if set). The optional color augmentation is applied before scaling.
        """
        if resampler is not None and not self.is_preprocessed:
            img = torch.tensor(img).permute(2, 0, 1)
            img = resampler.resample(img)
        else:
            if self.target_image_size:
                img = cv2.resize(
                    img,
                    (self.target_image_size[1], self.target_image_size[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
            img = torch.tensor(np.transpose(img, (2, 0, 1)))

        if color_aug_fn is not None:
            img = color_aug_fn(img)

        return img * 2 - 1

    def load_depth(self, seq, img_id, cam):
        """Project the Velodyne scan of ``img_id`` into perspective camera ``cam``.

        Args:
            seq (str): Sequence folder name.
            img_id (int): Raw frame number used in the file name.
            cam (str): "00" or "01".

        Returns:
            np.ndarray: Sparse depth map of shape (1, H, W), with (H, W) taken
            from ``self.target_image_size``. Pixels without a LiDAR return are
            0; where several points hit a pixel, the smallest depth is kept.
        """
        assert cam in ("00", "01")

        points = np.fromfile(
            os.path.join(
                self.data_path,
                "data_3d_raw",
                seq,
                "velodyne_points",
                "data",
                f"{img_id:010d}.bin",
            ),
            dtype=np.float32,
        ).reshape(-1, 4)
        # Overwrite the 4th channel with 1 to obtain homogeneous coordinates.
        points[:, 3] = 1.0

        T_velo_to_cam = self._calibs["T_velo_to_cam"][cam]
        K = self._calibs["K_perspective"]
        height, width = self.target_image_size[0], self.target_image_size[1]

        # Project the points to the camera; the third coordinate is the depth.
        pts = np.dot(K @ T_velo_to_cam[:3, :], points.T).T
        # Drop points at or behind the image plane: their perspective division is
        # meaningless and their non-positive depth would otherwise win the
        # per-pixel minimum and erase valid returns on the same pixel.
        pts = pts[pts[:, 2] > 0]
        pts[:, :2] = pts[:, :2] / pts[:, 2:3]

        # The projection is normalized to [-1, 1] -> transform to [0, W) x [0, H).
        cols = np.round((pts[:, 0] * 0.5 + 0.5) * width)
        rows = np.round((pts[:, 1] * 0.5 + 0.5) * height)

        in_bounds = (cols >= 0) & (rows >= 0) & (cols < width) & (rows < height)
        cols = cols[in_bounds].astype(np.int64)
        rows = rows[in_bounds].astype(np.int64)
        z = pts[in_bounds, 2]

        # Keep the closest point per pixel; ``minimum.at`` is unbuffered, so
        # repeated (row, col) pairs are all taken into account.
        depth = np.full((height, width), np.inf)
        np.minimum.at(depth, (rows, cols), z)
        depth[np.isinf(depth)] = 0

        return depth[None, :, :]

    def __getitem__(self, index: int):
        """Load one training sample.

        Returns:
            dict: "imgs" (processed image tensors), "projs" (intrinsics),
            "poses" (camera-to-world), "depths", "ts" (frame indices),
            "3d_bboxes", "segs" (only filled when the corresponding
            ``return_*`` flag is set, for the first sampled frame),
            "t__get_item__" (elapsed seconds) and "index".

        Raises:
            IndexError: If ``index >= self.length``.
        """
        _start_time = time.time()

        if index >= self.length:
            raise IndexError()

        if self._skip != 0:
            index += self._skip

        sequence, frame_id, _is_right = self._datapoints[index]
        seq_len = self._img_ids[sequence].shape[0]

        samples = self.frame_sampling_strategy.sample(
            frame_id, seq_len, self._poses[sequence], self._calibs
        )
        samples = [
            (cam, fid, self.get_img_id_from_id(sequence, fid)) for cam, fid in samples
        ]

        if self.color_aug:
            color_aug_fn = get_color_aug_fn(
                ColorJitter.get_params(
                    brightness=(0.8, 1.2),
                    contrast=(0.8, 1.2),
                    saturation=(0.8, 1.2),
                    hue=(-0.1, 0.1),
                )
            )
        else:
            color_aug_fn = None

        imgs = self.load_images(sequence, samples)
        imgs = [
            self.process_img(img, color_aug_fn=color_aug_fn, resampler=self._resamplers[cam])
            for (cam, _, _), img in zip(samples, imgs)
        ]

        # These poses are camera to world !!
        poses = [
            self._poses[sequence][fid, :, :] @ self._calibs["T_cam_to_pose"][cam]
            for cam, fid, _ in samples
        ]
        projs = [
            self._calibs["K_perspective"] if cam in ("00", "01") else self._calibs["K_fisheye"]
            for cam, _, _ in samples
        ]
        ids = [fid for _, fid, _ in samples]

        first_cam, _, first_img_id = samples[0]

        if self.return_depth:
            # load_depth expects (seq, img_id, cam).
            depths = [self.load_depth(sequence, first_img_id, first_cam)]
        else:
            depths = []

        if self.return_3d_bboxes:
            bboxes_3d = [self.get_3d_bboxes(sequence, first_img_id, poses[0], projs[0])]
        else:
            bboxes_3d = []

        if self.return_segmentation:
            segs = [self.load_segmentation(sequence, first_img_id)]
        else:
            segs = []

        _proc_time = np.array(time.time() - _start_time)

        data = {
            "imgs": imgs,
            "projs": projs,
            "poses": poses,
            "depths": depths,
            "ts": ids,
            "3d_bboxes": bboxes_3d,
            "segs": segs,
            "t__get_item__": np.array([_proc_time]),
            "index": np.array([index]),
        }

        return data

    def __len__(self) -> int:
        return self.length