File size: 3,541 Bytes
57746f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
"""
Waymo dataset
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import os
import numpy as np
import glob
from .builder import DATASETS
from .defaults import DefaultDataset
@DATASETS.register_module()
class WaymoDataset(DefaultDataset):
    """Waymo point-cloud dataset with optional multi-frame (temporal) fusion.

    Each entry of ``data_list`` is a path ``<data_root>/<split>/<seq>/<frame>``
    (see ``get_data_list``). Frames that share a parent directory form one
    sequence. When ``timestamp`` lists offsets besides 0, neighbouring frames
    of the same sequence are loaded, transformed into the requested frame's
    pose with ``align_pose`` and concatenated with it.

    Args:
        timestamp: Frame offsets to fuse, relative to the requested frame. The
            first entry must be 0 (the requested frame itself); e.g.
            ``(0, -1, -2)`` also fuses the two preceding frames. Offsets that
            fall outside the requested frame's sequence are skipped.
        reference_label: If False, the ``segment`` array of every reference
            (non-zero offset) frame is replaced by ``self.ignore_index``.
        timing_embedding: If True and more than one timestamp is given, a
            column holding each point's frame offset (0 for the requested
            frame) is appended to ``strength``. With a single timestamp the
            frame is returned unchanged.
        **kwargs: Forwarded to ``DefaultDataset``.

    Raises:
        ValueError: If ``timestamp`` is empty or does not start with 0.
    """

    def __init__(
        self,
        timestamp=(0,),
        reference_label=True,
        timing_embedding=False,
        **kwargs,
    ):
        # Normalise to a tuple so list-valued configs (e.g. ``[0]``) behave
        # exactly like tuples in ``get_data``.
        timestamp = tuple(timestamp)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(timestamp) == 0 or timestamp[0] != 0:
            raise ValueError(
                f"timestamp must start with 0 (the current frame), got {timestamp}"
            )
        super().__init__(**kwargs)
        self.timestamp = timestamp
        self.reference_label = reference_label
        self.timing_embedding = timing_embedding
        # Sort by (sequence dir, frame name) so each sequence is contiguous and
        # sequences appear in the same order np.unique sorts their dirnames; a
        # plain string sort can disagree when one dir name is a prefix of
        # another. Frame offsets map to index offsets only if frame names sort
        # chronologically within a sequence (assumed; TODO confirm naming).
        self.data_list = sorted(self.data_list, key=os.path.split)
        # sequence_offset[s]: index of the first frame of sequence s.
        # sequence_index[i]:  sequence id of frame i.
        _, self.sequence_offset, self.sequence_index = np.unique(
            [os.path.dirname(data) for data in self.data_list],
            return_index=True,
            return_inverse=True,
        )
        # Sentinel so [sequence_offset[s], sequence_offset[s + 1]) spans sequence s.
        self.sequence_offset = np.append(self.sequence_offset, len(self.data_list))

    def get_data_list(self):
        """Return every ``<data_root>/<split>/*/*`` path for the configured splits.

        ``self.split`` may be one split name or a list of names; a string is
        wrapped into a one-element list and stored back on ``self.split``.
        The order is whatever ``glob`` yields; ``__init__`` sorts the list.
        """
        if isinstance(self.split, str):
            self.split = [self.split]
        data_list = []
        for split in self.split:
            data_list += glob.glob(os.path.join(self.data_root, split, "*", "*"))
        return data_list

    @staticmethod
    def align_pose(coord, pose, target_pose):
        """Map points from ``pose``'s frame into ``target_pose``'s frame.

        Args:
            coord: (N, 3) point coordinates.
            pose: 4x4 homogeneous transform of the frame ``coord`` lives in.
            target_pose: 4x4 homogeneous transform of the destination frame.
                (Presumably both are frame-to-world poses; confirm against the
                preprocessing script.)

        Returns:
            (N, 3) array equal to ``inv(target_pose) @ pose @ [coord; 1]``.
        """
        coord = np.hstack((coord, np.ones_like(coord[:, :1])))
        pose_align = np.matmul(np.linalg.inv(target_pose), pose)
        coord = (pose_align @ coord.T).T[:, :3]
        return coord

    @staticmethod
    def _append_timestamp(strength, timestamp):
        """Append a column filled with ``timestamp`` to ``strength``.

        ``np.column_stack`` treats a 1-D ``strength`` as a single column, so the
        result always has one row per point; ``np.hstack`` would instead join
        two 1-D arrays end to end and double the point count.
        """
        return np.column_stack((strength, np.ones_like(strength) * timestamp))

    def get_single_frame(self, idx):
        """Load one frame, unfused, via ``DefaultDataset.get_data``."""
        return super().get_data(idx)

    def get_data(self, idx):
        """Return frame ``idx`` fused with its in-sequence reference frames.

        With a single timestamp this is just ``get_single_frame(idx)``.
        Otherwise the frame's ``name`` and ``pose`` entries are removed. Every
        remaining entry is concatenated along axis 0 with the same entry of
        each in-range reference frame. Reference ``coord`` is first aligned to
        the requested frame's pose. ``name`` is re-added at the end.
        """
        idx = idx % len(self.data_list)
        if len(self.timestamp) == 1:
            # Only the current frame requested: nothing to fuse.
            return self.get_single_frame(idx)

        sequence_index = self.sequence_index[idx]
        lower, upper = self.sequence_offset[[sequence_index, sequence_index + 1]]
        major_frame = self.get_single_frame(idx)
        name = major_frame.pop("name")
        target_pose = major_frame.pop("pose")
        if self.timing_embedding:
            # Offset 0 for the current frame keeps its ``strength`` the same
            # width as the reference frames' (and as samples whose references
            # all fall outside the sequence), so concatenation and batching
            # see a consistent channel count.
            major_frame["strength"] = self._append_timestamp(
                major_frame["strength"], 0
            )
        fused = {key: [value] for key, value in major_frame.items()}

        for timestamp in self.timestamp[1:]:
            refer_idx = idx + timestamp
            # Skip offsets that leave this frame's sequence.
            if not lower <= refer_idx < upper:
                continue
            refer_frame = self.get_single_frame(refer_idx)
            refer_frame.pop("name")
            pose = refer_frame.pop("pose")
            refer_frame["coord"] = self.align_pose(
                refer_frame["coord"], pose, target_pose
            )
            if not self.reference_label and "segment" in refer_frame:
                refer_frame["segment"] = (
                    np.ones_like(refer_frame["segment"]) * self.ignore_index
                )
            if self.timing_embedding:
                refer_frame["strength"] = self._append_timestamp(
                    refer_frame["strength"], timestamp
                )
            for key, values in fused.items():
                values.append(refer_frame[key])

        data_dict = {
            key: np.concatenate(values, axis=0) for key, values in fused.items()
        }
        data_dict["name"] = name
        return data_dict

    def get_data_name(self, idx):
        """Return ``"<sequence>_<frame>"`` for ``data_list[idx % len]``."""
        file_path = self.data_list[idx % len(self.data_list)]
        sequence_path, frame_name = os.path.split(file_path)
        sequence_name = os.path.basename(sequence_path)
        data_name = f"{sequence_name}_{frame_name}"
        return data_name
|