import glob
import os
import numpy as np
import yaml
from PIL import Image
import xml.etree.ElementTree as ET

from lidm.data.base import DatasetBase
from .annotated_dataset import Annotated3DObjectsDataset
from .conditional_builder.utils import corners_3d_to_2d
from .helper_types import Annotation
from ..utils.lidar_utils import pcd2range, pcd2coord2d, range2pcd

# TODO: add annotation categories and semantic categories
CATEGORIES = ['ignore', 'car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', 'motorcyclist',
              'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', 'trunk', 'terrain',
              'pole', 'traffic-sign']
CATE2LABEL = {k: v for v, k in enumerate(CATEGORIES)}  # 0: invalid, 1~19: semantic categories
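# one RGB triplet per entry in CATEGORIES, used to colorize semantic maps for visualization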
LABEL2RGB = np.array([(0, 0, 0), (0, 0, 142), (119, 11, 32), (0, 0, 230), (0, 0, 70), (0, 0, 90), (220, 20, 60),
                      (255, 0, 0), (0, 0, 110), (128, 64, 128), (250, 170, 160), (244, 35, 232), (230, 150, 140),
                      (70, 70, 70), (190, 153, 153), (107, 142, 35), (0, 80, 100), (230, 150, 140), (153, 153, 153),
                      (220, 220, 0)])
CAMERAS = ['CAM_FRONT']
BBOX_CATS = ['car', 'people', 'cycle']
BBOX_CAT2LABEL = {'car': 0, 'truck': 0, 'bus': 0, 'caravan': 0, 'person': 1, 'rider': 2, 'motorcycle': 2, 'bicycle': 2}
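# BBOX_CAT2LABEL collapses fine-grained KITTI-360 labels onto the three coarse
# classes in BBOX_CATS: car=0, people=1, cycle=2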
# train + test
SEM_KITTI_TRAIN_SET = ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10']
KITTI_TRAIN_SET = SEM_KITTI_TRAIN_SET + ['11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21']
KITTI360_TRAIN_SET = ['00', '02', '04', '05', '06', '07', '09', '10'] + ['08']  # partial test data at '02' sequence
CAM_KITTI360_TRAIN_SET = ['00', '04', '05', '06', '07', '08', '09', '10']  # camera mismatches lidar in '02'
# validation
SEM_KITTI_VAL_SET = KITTI_VAL_SET = ['08']
CAM_KITTI360_VAL_SET = KITTI360_VAL_SET = ['03']


class KITTIBase(DatasetBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dataset_name = 'kitti'
        self.num_sem_cats = kwargs['dataset_config'].num_sem_cats + 1

    def load_lidar_sweep(self, path):
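        # KITTI velodyne .bin scans store one float32 (x, y, z, intensity) tuple per point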
        scan = np.fromfile(path, dtype=np.float32)
        scan = scan.reshape((-1, 4))
        points = scan[:, 0:3]  # get xyz
        return points

    def load_semantic_map(self, path, pcd):
        raise NotImplementedError

    def load_camera(self, path):
        raise NotImplementedError

    def __getitem__(self, idx):
        example = dict()
        data_path = self.data[idx]

        # lidar point cloud
        sweep = self.load_lidar_sweep(data_path)
        if self.lidar_transform:
            sweep, _ = self.lidar_transform(sweep, None)

        if self.condition_key == 'segmentation':
            # semantic maps
            proj_range, sem_map = self.load_semantic_map(data_path, sweep)
            example[self.condition_key] = sem_map
        else:
            proj_range, _ = pcd2range(sweep, self.img_size, self.fov, self.depth_range)
        proj_range, proj_mask = self.process_scan(proj_range)
        example['image'], example['mask'] = proj_range, proj_mask

        if self.return_pcd:
            reproj_sweep, _, _ = range2pcd(proj_range[0] * .5 + .5, self.fov, self.depth_range, self.depth_scale, self.log_scale)
            example['raw'] = sweep
            example['reproj'] = reproj_sweep.astype(np.float32)

        # image degradation
        if self.degradation_transform:
            degraded_proj_range = self.degradation_transform(proj_range)
            example['degraded_image'] = degraded_proj_range

        # cameras
        if self.condition_key == 'camera':
            cameras = self.load_camera(data_path)
            example[self.condition_key] = cameras

        return example


class SemanticKITTIBase(KITTIBase):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert self.condition_key in ['segmentation']  # for segmentation input only
        self.label2rgb = LABEL2RGB

    def prepare_data(self):
        # read data paths from SemanticKITTI
        for seq_id in eval('SEM_KITTI_%s_SET' % self.split.upper()):
            self.data.extend(glob.glob(os.path.join(
                self.data_root, f'dataset/sequences/{seq_id}/velodyne/*.bin')))

        # read label mapping (raw id -> learning id)
        with open('./data/config/semantic-kitti.yaml', 'r') as f:
            data_config = yaml.safe_load(f)
        remap_dict = data_config["learning_map"]
        max_key = max(remap_dict.keys())
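        # build a lookup table instead of a per-point dict lookup; padding past
        # max_key lets stray raw ids just above it still index safely (-> 0, 'ignore')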
        self.learning_map = np.zeros((max_key + 100), dtype=np.int32)
        self.learning_map[list(remap_dict.keys())] = list(remap_dict.values())

    def load_semantic_map(self, path, pcd):
        label_path = path.replace('velodyne', 'labels').replace('.bin', '.label')
        labels = np.fromfile(label_path, dtype=np.uint32)
        labels = labels.reshape((-1))
        labels = labels & 0xFFFF  # semantic label in lower half
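        # (the upper 16 bits of each .label entry carry the instance id)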
        labels = self.learning_map[labels]

        proj_range, sem_map = pcd2range(pcd, self.img_size, self.fov, self.depth_range, labels=labels)
        sem_map = sem_map.astype(np.int64)
        if self.filtered_map_cats is not None:
            sem_map[np.isin(sem_map, self.filtered_map_cats)] = 0  # set filtered category as noise
        onehot = np.eye(self.num_sem_cats, dtype=np.float32)[sem_map].transpose(2, 0, 1)
        return proj_range, onehot


class SemanticKITTITrain(SemanticKITTIBase):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset/SemanticKITTI', split='train', **kwargs)


class SemanticKITTIValidation(SemanticKITTIBase):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset/SemanticKITTI', split='val', **kwargs)


class KITTI360Base(KITTIBase):
    def __init__(self, split_per_view=None, **kwargs):
        super().__init__(**kwargs)
        self.split_per_view = split_per_view
        if self.condition_key == 'camera':
            assert self.split_per_view is not None, 'For camera-to-lidar, need to specify split_per_view'

    def prepare_data(self):
        # read data paths
        self.data = []
        if self.condition_key == 'camera':
            seq_list = eval('CAM_KITTI360_%s_SET' % self.split.upper())
        else:
            seq_list = eval('KITTI360_%s_SET' % self.split.upper())
        for seq_id in seq_list:
            self.data.extend(glob.glob(os.path.join(
                self.data_root, f'data_3d_raw/2013_05_28_drive_00{seq_id}_sync/velodyne_points/data/*.bin')))

    def random_drop_camera(self, camera_list):
        if np.random.rand() < self.aug_config['camera_drop'] and self.split == 'train':
            camera_list = [np.zeros_like(c) if i != len(camera_list) // 2 else c for i, c in enumerate(camera_list)]  # keep the middle view only
        return camera_list

    def load_camera(self, path):
        camera_path = path.replace('data_3d_raw', 'data_2d_camera').replace('velodyne_points/data', 'image_00/data_rect').replace('.bin', '.png')
        camera = np.array(Image.open(camera_path)).astype(np.float32) / 255.
        camera = camera.transpose(2, 0, 1)
        if self.view_transform:
            camera = self.view_transform(camera)
        camera_list = np.split(camera, self.split_per_view, axis=2)  # split into n chunks as different views
        camera_list = self.random_drop_camera(camera_list)
        return camera_list


class KITTI360Train(KITTI360Base):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset/KITTI-360', split='train', **kwargs)


class KITTI360Validation(KITTI360Base):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset/KITTI-360', split='val', **kwargs)


class AnnotatedKITTI360Base(Annotated3DObjectsDataset, KITTI360Base):
    def __init__(self, **kwargs):
        self.id_bbox_dict = dict()
        self.id_label_dict = dict()
        Annotated3DObjectsDataset.__init__(self, **kwargs)
        KITTI360Base.__init__(self, **kwargs)
        assert self.condition_key in ['center', 'bbox']  # for annotated images only

    def parseOpencvMatrix(self, node):
        rows = int(node.find('rows').text)
        cols = int(node.find('cols').text)
        data = node.find('data').text.split(' ')

        mat = []
        for d in data:
            d = d.replace('\n', '')
            if len(d) < 1:
                continue
            mat.append(float(d))
        mat = np.reshape(mat, [rows, cols])
        return mat

    def parseVertices(self, child):
        transform = self.parseOpencvMatrix(child.find('transform'))
        R = transform[:3, :3]
        T = transform[:3, 3]
        vertices = self.parseOpencvMatrix(child.find('vertices'))
        vertices = np.matmul(R, vertices.transpose()).transpose() + T
        return vertices

    def parse_bbox_xml(self, path):
        tree = ET.parse(path)
        root = tree.getroot()
        bbox_dict = dict()
        label_dict = dict()
        for child in root:
            if child.find('transform') is None:
                continue
            label_name = child.find('label').text
            if label_name not in BBOX_CAT2LABEL:
                continue
            label = BBOX_CAT2LABEL[label_name]
            timestamp = int(child.find('timestamp').text)
            # verts = self.parseVertices(child)
            verts = self.parseOpencvMatrix(child.find('vertices'))[:8]
            if timestamp in bbox_dict:
                bbox_dict[timestamp].append(verts)
                label_dict[timestamp].append(label)
            else:
                bbox_dict[timestamp] = [verts]
                label_dict[timestamp] = [label]
        return bbox_dict, label_dict

    def prepare_data(self):
        KITTI360Base.prepare_data(self)
        self.data = [p for p in self.data if '2013_05_28_drive_0008_sync' not in p]  # remove unlabeled sequence 08

        seq_list = eval('KITTI360_%s_SET' % self.split.upper())
        for seq_id in seq_list:
            if seq_id != '08':
                xml_path = os.path.join(self.data_root, f'data_3d_bboxes/train/2013_05_28_drive_00{seq_id}_sync.xml')
                bbox_dict, label_dict = self.parse_bbox_xml(xml_path)
                self.id_bbox_dict[seq_id] = bbox_dict
                self.id_label_dict[seq_id] = label_dict

    def load_annotation(self, path):
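        # e.g. '.../2013_05_28_drive_0003_sync/velodyne_points/data/0000000123.bin'
        # -> seq_id '03', timestamp 123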
        seq_id = path.split('/')[-4].split('_')[-2][-2:]
        timestamp = int(path.split('/')[-1].replace('.bin', ''))
        verts_list = self.id_bbox_dict[seq_id][timestamp]
        label_list = self.id_label_dict[seq_id][timestamp]
        if self.condition_key == 'bbox':
            points = np.stack(verts_list)  # N x 8 x 3 corner coordinates
        elif self.condition_key == 'center':
            # box centers as midpoints of two opposite corners (N x 3)
            verts = np.stack(verts_list)
            points = (verts[:, 0] + verts[:, 6]) / 2.
        else:
            raise NotImplementedError
        labels = np.array(label_list)
        if self.anno_transform:
            points, labels = self.anno_transform(points, labels)
        return points, labels

    def __getitem__(self, idx):
        example = dict()
        data_path = self.data[idx]

        # lidar point cloud
        sweep = self.load_lidar_sweep(data_path)

        # annotations
        bbox_points, bbox_labels = self.load_annotation(data_path)
        if self.lidar_transform:
            sweep, bbox_points = self.lidar_transform(sweep, bbox_points)

        # point cloud -> range
        proj_range, _ = pcd2range(sweep, self.img_size, self.fov, self.depth_range)
        proj_range, proj_mask = self.process_scan(proj_range)
        example['image'], example['mask'] = proj_range, proj_mask
        if self.return_pcd:
            example['reproj'] = sweep

        # annotation -> range
        # NOTE: no further image-space transform is needed for the bbox points, since
        # pcd2coord2d maps them into range-image coordinates rather than 3D space
        proj_bbox_points, proj_bbox_labels = pcd2coord2d(bbox_points, self.fov, self.depth_range, labels=bbox_labels)
        builder = self.conditional_builders[self.condition_key]
        if self.condition_key == 'bbox':
            proj_bbox_points = corners_3d_to_2d(proj_bbox_points)
            annotations = [Annotation(bbox=bbox.flatten(), category_id=label) for bbox, label in
                           zip(proj_bbox_points, proj_bbox_labels)]
        else:
            annotations = [Annotation(center=center, category_id=label) for center, label in
                           zip(proj_bbox_points, proj_bbox_labels)]
        example[self.condition_key] = builder.build(annotations)
        return example


class AnnotatedKITTI360Train(AnnotatedKITTI360Base):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset/KITTI-360', split='train', cats=BBOX_CATS, **kwargs)


class AnnotatedKITTI360Validation(AnnotatedKITTI360Base):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset/KITTI-360', split='val', cats=BBOX_CATS, **kwargs)


class KITTIImageBase(KITTIBase):
    """
    Range-image-only dataset combining KITTI-360 and SemanticKITTI.
    #Samples (Train): 98014, #Samples (Val): 3511
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert self.condition_key in [None, 'image']  # for image input only

    def prepare_data(self):
        # read data paths from KITTI-360
        self.data = []
        for seq_id in eval('KITTI360_%s_SET' % self.split.upper()):
            self.data.extend(glob.glob(os.path.join(
                self.data_root, f'KITTI-360/data_3d_raw/2013_05_28_drive_00{seq_id}_sync/velodyne_points/data/*.bin')))

        # read data paths from SemanticKITTI
        for seq_id in eval('KITTI_%s_SET' % self.split.upper()):
            self.data.extend(glob.glob(os.path.join(
                self.data_root, f'SemanticKITTI/dataset/sequences/{seq_id}/velodyne/*.bin')))


class KITTIImageTrain(KITTIImageBase):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset', split='train', **kwargs)


class KITTIImageValidation(KITTIImageBase):
    def __init__(self, **kwargs):
        super().__init__(data_root='./dataset', split='val', **kwargs)
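

if __name__ == '__main__':
    # Minimal, self-contained sketch of the label handling used in
    # SemanticKITTIBase.prepare_data / load_semantic_map: unpack the semantic id
    # from the lower 16 bits, then remap it through a padded lookup table.
    # The mapping below is a hypothetical subset for illustration, not the real
    # semantic-kitti.yaml learning_map.
    remap = {0: 0, 10: 1, 11: 2, 252: 1}  # raw id -> learning id (illustrative)
    lut = np.zeros(max(remap) + 100, dtype=np.int32)
    lut[list(remap.keys())] = list(remap.values())

    raw = np.array([10, 252, 0, 11], dtype=np.uint32)
    packed = raw | (7 << 16)  # fake instance id 7 packed into the upper 16 bits
    print(lut[packed & 0xFFFF])  # -> [1 1 0 2], identical to lut[raw]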