SceneDINO / scenedino /datasets /kitti_360_v2.py
jev-aleks's picture
scenedino init
9e15541
import os
import random
import time
import xml.etree.ElementTree as ET
from collections import Counter, defaultdict
from pathlib import Path
from typing import Optional
from dotdict import dotdict
import yaml
import cv2
import numpy as np
import omegaconf
import torch
import torch.nn.functional as F
from scipy.spatial.transform import Rotation
from torch.utils.data import Dataset
from torchvision.transforms import ColorJitter
from scenedino.common.geometry import estimate_frustum_overlap_2
from scenedino.common.point_sampling import regular_grid
from scenedino.datasets.old_kitti_360 import FisheyeToPinholeSampler, OldKITTI360Dataset
from datasets.kitti_360.annotation import KITTI360Bbox3D
from scenedino.common.augmentation import get_color_aug_fn
def three_to_four(matrix):
    """Embed 3x3 matrices into 4x4 matrices padded with the identity.

    Args:
        matrix (torch.Tensor): Tensor of shape ``(..., 3, 3)``. Any number of
            leading batch dimensions of any size is supported.

    Returns:
        torch.Tensor: Tensor of shape ``(..., 4, 4)`` with the same dtype and
        device as ``matrix``. The upper-left 3x3 block holds ``matrix``; all
        other entries come from the 4x4 identity.
    """
    identity = torch.eye(4, dtype=matrix.dtype, device=matrix.device)
    # Expand to the real batch shape (not just singleton dims) so batches with
    # more than one matrix can be written; clone because expand() returns a
    # view in which every batch entry aliases the same 4x4 buffer.
    new_matrix = identity.expand(*matrix.shape[:-2], 4, 4).clone()
    new_matrix[..., :3, :3] = matrix
    return new_matrix
class FrameSamplingStrategy:
    """Strategy that determines how frames should be sampled around a given base frame."""

    def sample(self, index, nbr_samples, poses=None, calibs=None):
        """Select the frames to load for the base frame ``index``.

        The optional ``poses`` and ``calibs`` arguments keep this signature
        compatible with strategies that need geometry to pick frames.

        Args:
            index: Index of the base frame.
            nbr_samples: Number of frames available to sample from.
            poses: Optional pose data for pose-aware strategies.
            calibs: Optional calibration data for pose-aware strategies.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
class OverlapFrameSamplingStrategy(FrameSamplingStrategy):
    """Sample frame pairs whose estimated view frustum overlaps the encoder frame."""

    def __init__(self, n_frames, cams=("00", "01"), **kwargs) -> None:
        """Configure the overlap-based sampler.

        Args:
            n_frames (int): Number of frames to collect. Frames are added in
                pairs, so one more than ``n_frames`` may be returned.
            cams (tuple, optional): Stored on the instance as ``self.cams``.
            **kwargs: Optional settings:
                max_samples (int): Maximum number of candidate draws. Defaults to 128.
                min_ratio (float): Minimum overlap for a candidate to be kept. Defaults to 0.4.
                max_steps (int): Stored on the instance as ``self.max_steps``. Defaults to 5.
                ranges_00 (dict): Maps camera id to a ``(start, stop)`` frame
                    offset range used when camera "00" is the base camera.
                ranges_01 (dict): Same, for base camera "01". When omitted it
                    is derived from ``ranges_00``.
        """
        super().__init__()
        self.n_frames = n_frames
        self.cams = cams
        self.max_samples = kwargs.get("max_samples", 128)
        self.min_ratio = kwargs.get("min_ratio", .4)
        self.max_steps = kwargs.get("max_steps", 5)

        self.ranges_00 = kwargs.get("ranges_00", {
            "00": (-10, 20),
            # "01": (10, 45),
            "02": (10, 50),
            "03": (10, 50),
        })

        if "ranges_01" in kwargs:
            self.ranges_01 = kwargs["ranges_01"]
        else:
            # Derive the default only when it is needed, and only from keys
            # that exist, so a custom ``ranges_00`` cannot trigger a KeyError.
            mirrored = (("01", "00"), ("02", "02"), ("03", "03"))
            self.ranges_01 = {
                dst: self.ranges_00[src]
                for dst, src in mirrored
                if src in self.ranges_00
            }

        self.ranges = {
            "00": self.ranges_00,
            "01": self.ranges_01,
        }

    @staticmethod
    def _intrinsics(calibs, cam):
        """Return the float32 intrinsics of ``cam`` with a leading axis of size 1."""
        key = "K_perspective" if cam in ("00", "01") else "K_fisheye"
        return torch.tensor(calibs[key], dtype=torch.float32)[None, :, :]

    def sample(self, index, nbr_samples, poses, calibs):
        """Pick ``(cam, frame_index)`` tuples for the base frame ``index``.

        The first pair is the encoder frame (camera "00" or "01", chosen at
        random) and one of its temporal neighbours. Further pairs are drawn
        from the offset ranges configured for the base camera and kept when
        their overlap with the encoder frame reaches ``min_ratio``.

        Args:
            index (int): Index of the base frame within the sequence.
            nbr_samples (int): Number of frames in the sequence; returned
                indices are clamped to ``[0, nbr_samples - 1]``.
            poses: Per-frame poses, indexable as ``poses[i]``.
                NOTE(review): presumably 4x4 matrices - confirm with caller.
            calibs (dict): Must provide "K_perspective", "K_fisheye" and
                "T_cam_to_pose" (indexed by camera id).

        Returns:
            list: ``(cam, frame_index)`` tuples, encoder frame first.
        """
        # as_tensor avoids re-copying the whole pose array on every call when
        # it already has a compatible dtype; poses are only read below.
        poses = torch.as_tensor(poses, dtype=torch.float32)
        last_index = nbr_samples - 1

        base_cam = "00" if random.random() < .5 else "01"
        all_ranges = self.ranges[base_cam]
        candidate_cams = list(all_ranges.keys())

        encoder_proj = self._intrinsics(calibs, base_cam)
        encoder_pose = poses[index:index+1, :, :] @ calibs["T_cam_to_pose"][base_cam]

        off = 1 if random.random() > .5 else -1
        ids = [(base_cam, index), (base_cam, index + off)]

        for i in range(self.max_samples):
            if len(ids) >= self.n_frames:
                break

            c = random.choice(candidate_cams)
            index_offset = random.choice(range(all_ranges[c][0], all_ranges[c][1]))
            off = 1 if random.random() >= .5 else -1

            base_idx = max(min(index + index_offset, last_index), 0)
            neighbour_idx = max(min(base_idx + off, last_index), 0)

            proj = self._intrinsics(calibs, c)
            extr = poses[base_idx:base_idx+1, :, :] @ calibs["T_cam_to_pose"][c]

            overlap = estimate_frustum_overlap_2(proj, extr, encoder_proj, encoder_pose).item()

            # Reject low-overlap candidates only while enough draws remain to
            # still collect the requested number of frames.
            draws_can_still_fill = (self.max_samples - i) * 2 > (self.n_frames - len(ids))
            if overlap < self.min_ratio and draws_can_still_fill:
                continue

            ids += [(c, base_idx), (c, neighbour_idx)]

        # The encoder neighbour was added unclamped; clamp every index once here.
        ids = [(cam, max(min(frame_idx, last_index), 0)) for cam, frame_idx in ids]
        return ids
class KITTI360DatasetV2(OldKITTI360Dataset):
    """KITTI-360 dataset that picks its frames with an overlap-based sampler.

    Each item consists of an encoder frame plus further frames chosen by
    ``OverlapFrameSamplingStrategy`` from the perspective ("00", "01") and
    fisheye ("02", "03") cameras.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Forward all arguments to the parent and set up resamplers and sampler."""
        super().__init__(*args, **kwargs)
        # Perspective cameras are used as-is; the fisheye resamplers are
        # expected to have been created by the parent constructor.
        self._resamplers = {
            "00": None,
            "01": None,
            "02": self._resampler_02,
            "03": self._resampler_03,
        }
        self.frame_sampling_strategy = OverlapFrameSamplingStrategy(n_frames=self.frame_count)

    def load_images(self, seq, img_ids):
        """Load the RGB images for the given samples.

        Args:
            seq (str): Sequence folder name.
            img_ids: Iterable of ``(cam, frame_index, img_id)`` tuples.

        Returns:
            list: One float32 RGB array per sample, scaled to ``[0, 1]``.

        Raises:
            FileNotFoundError: If an image cannot be read from disk.
        """
        imgs = []
        for cam, _, img_id in img_ids:
            folder = self._perspective_folder if cam in ("00", "01") else self._fisheye_folder
            path = os.path.join(
                self.data_path,
                "data_2d_raw",
                seq,
                f"image_{cam}",
                folder,
                f"{img_id:010d}.png",
            )
            img_bgr = cv2.imread(path)
            if img_bgr is None:
                # cv2.imread signals failure by returning None; fail with the
                # offending path instead of an opaque cvtColor error.
                raise FileNotFoundError(f"Could not read image: {path}")
            img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
            imgs.append(img)
        return imgs

    def process_img(
        self,
        img: np.ndarray,
        color_aug_fn=None,
        resampler: FisheyeToPinholeSampler = None,
    ):
        """Convert a loaded image into a network-ready tensor.

        Args:
            img (np.ndarray): Channel-last image array.
            color_aug_fn (callable, optional): Applied to the tensor before
                the final rescaling.
            resampler (FisheyeToPinholeSampler, optional): Used instead of
                resizing, unless the data is already preprocessed.

        Returns:
            torch.Tensor: Image tensor mapped through ``img * 2 - 1``, i.e.
            to ``[-1, 1]`` for inputs in ``[0, 1]``.
        """
        if resampler is not None and not self.is_preprocessed:
            img = torch.tensor(img).permute(2, 0, 1)
            img = resampler.resample(img)
        else:
            if self.target_image_size:
                img = cv2.resize(
                    img,
                    (self.target_image_size[1], self.target_image_size[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
            img = np.transpose(img, (2, 0, 1))
            img = torch.tensor(img)

        if color_aug_fn is not None:
            img = color_aug_fn(img)

        img = img * 2 - 1
        return img

    def load_depth(self, seq, img_id, cam):
        """Project the velodyne scan of a frame into a sparse depth map.

        Args:
            seq (str): Sequence folder name.
            img_id (int): Image id used to locate the velodyne ``.bin`` file.
            cam (str): Perspective camera id, "00" or "01".

        Returns:
            np.ndarray: Array of shape ``(1, H, W)`` with ``(H, W)`` taken
            from ``self.target_image_size``. Pixels without a point, or whose
            closest point has negative depth, are 0.
        """
        assert cam in ("00", "01")

        height, width = self.target_image_size[0], self.target_image_size[1]

        points = np.fromfile(
            os.path.join(
                self.data_path,
                "data_3d_raw",
                seq,
                "velodyne_points",
                "data",
                f"{img_id:010d}.bin",
            ),
            dtype=np.float32,
        ).reshape(-1, 4)
        # Replace the fourth column by 1 to obtain homogeneous coordinates.
        points[:, 3] = 1.0

        T_velo_to_cam = self._calibs["T_velo_to_cam"][cam]
        K = self._calibs["K_perspective"]

        # project the points to the camera
        velo_pts_im = np.dot(K @ T_velo_to_cam[:3, :], points.T).T
        velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., None]

        # the projection is normalized to [-1, 1] -> transform to [0, height-1] x [0, width-1]
        velo_pts_im[:, 0] = np.round((velo_pts_im[:, 0] * 0.5 + 0.5) * width)
        velo_pts_im[:, 1] = np.round((velo_pts_im[:, 1] * 0.5 + 0.5) * height)

        # check if in bounds
        val_inds = (
            (velo_pts_im[:, 0] >= 0)
            & (velo_pts_im[:, 1] >= 0)
            & (velo_pts_im[:, 0] < width)
            & (velo_pts_im[:, 1] < height)
        )
        velo_pts_im = velo_pts_im[val_inds, :]

        cols = velo_pts_im[:, 0].astype(np.int32)
        rows = velo_pts_im[:, 1].astype(np.int32)

        # Keep the closest point per pixel. An unbuffered minimum handles
        # repeated pixel indices correctly; the previous linear index
        # (row * (width - 1) + col - 1) merged distinct pixels.
        depth = np.full((height, width), np.inf)
        np.minimum.at(depth, (rows, cols), velo_pts_im[:, 2])
        depth[np.isinf(depth)] = 0
        depth[depth < 0] = 0

        return depth[None, :, :]

    def __getitem__(self, index: int):
        """Assemble the data dictionary for one datapoint.

        Args:
            index (int): Dataset index (before ``self._skip`` is added).

        Returns:
            dict: Keys "imgs", "projs", "poses", "depths", "ts", "3d_bboxes",
            "segs", "t__get_item__" and "index". Depth, boxes and
            segmentation are only loaded for the first (encoder) sample and
            are empty lists when the matching ``return_*`` flag is off.

        Raises:
            IndexError: If ``index`` is not smaller than ``self.length``.
        """
        _start_time = time.time()

        if index >= self.length:
            raise IndexError()

        if self._skip != 0:
            index += self._skip

        sequence, frame_idx, _is_right = self._datapoints[index]
        seq_len = self._img_ids[sequence].shape[0]

        samples = self.frame_sampling_strategy.sample(
            frame_idx, seq_len, self._poses[sequence], self._calibs
        )
        # Extend each (cam, frame_index) with the image id used for file names.
        samples = [
            (cam, idx, self.get_img_id_from_id(sequence, idx)) for cam, idx in samples
        ]

        if self.color_aug:
            color_aug_fn = get_color_aug_fn(
                ColorJitter.get_params(
                    brightness=(0.8, 1.2),
                    contrast=(0.8, 1.2),
                    saturation=(0.8, 1.2),
                    hue=(-0.1, 0.1),
                )
            )
        else:
            color_aug_fn = None

        imgs = self.load_images(sequence, samples)
        imgs = [
            self.process_img(img, color_aug_fn=color_aug_fn, resampler=self._resamplers[cam])
            for (cam, _, _), img in zip(samples, imgs)
        ]

        # These poses are camera to world !!
        poses = [
            self._poses[sequence][idx, :, :] @ self._calibs["T_cam_to_pose"][cam]
            for cam, idx, _ in samples
        ]
        projs = [
            self._calibs["K_perspective"] if cam in ("00", "01") else self._calibs["K_fisheye"]
            for cam, _, _ in samples
        ]
        ids = [idx for _, idx, _ in samples]

        encoder_cam, _, encoder_img_id = samples[0]

        if self.return_depth:
            # load_depth expects (seq, img_id, cam); pass the camera id, not
            # the frame index.
            depths = [self.load_depth(sequence, encoder_img_id, encoder_cam)]
        else:
            depths = []

        if self.return_3d_bboxes:
            bboxes_3d = [self.get_3d_bboxes(sequence, encoder_img_id, poses[0], projs[0])]
        else:
            bboxes_3d = []

        if self.return_segmentation:
            segs = [self.load_segmentation(sequence, encoder_img_id)]
        else:
            segs = []

        _proc_time = np.array(time.time() - _start_time)

        data = {
            "imgs": imgs,
            "projs": projs,
            "poses": poses,
            "depths": depths,
            "ts": ids,
            "3d_bboxes": bboxes_3d,
            "segs": segs,
            "t__get_item__": np.array([_proc_time]),
            "index": np.array([index]),
        }

        return data

    def __len__(self) -> int:
        """Return the number of datapoints exposed by the dataset."""
        return self.length