from torch.utils.data import Dataset
class HUGSIM_dataset(Dataset):
    """Index-based dataset over a list of precomputed camera views.

    Each item pairs a view's index with the index of an earlier view
    ``gap`` frames back (``-1`` when no such frame exists) and the
    view's ground-truth data. Ground-truth attributes that are absent
    on the view (already ``None``) are returned as ``None``.
    """

    # Frame gap between a view and its "previous" view, per data source.
    # Unknown data types fall back to the default gap.
    _GAPS = {'kitti360': 4, 'waymo': 3, 'kitti': 2}
    _DEFAULT_GAP = 6

    def __init__(self, views, data_type):
        """Store the view list and pick the frame gap for *data_type*.

        Args:
            views: sequence of view objects exposing ``original_image``,
                ``semantic2d``, ``optical_gt``, ``depth`` and ``mask``
                attributes (any of the last four may be ``None``).
            data_type: dataset name ('kitti360', 'waymo', 'kitti', or
                anything else for the default gap of 6).
        """
        super().__init__()
        self.views = views
        self.data_type = data_type
        # dict lookup replaces the original if/elif chain; same mapping.
        self.gap = self._GAPS.get(data_type, self._DEFAULT_GAP)

    def __getitem__(self, index):
        """Return ``(index, prev_index, image, semantic, optical, depth, mask)``.

        ``prev_index`` is ``index - gap`` when that is non-negative,
        otherwise ``-1`` (no previous frame available).
        """
        prev_index = index - self.gap if index >= self.gap else -1
        cam = self.views[index]
        # The original guarded each attribute with `if x is not None`,
        # but both branches yielded the same value — plain attribute
        # access is equivalent.
        return (index, prev_index, cam.original_image, cam.semantic2d,
                cam.optical_gt, cam.depth, cam.mask)

    def __len__(self):
        """Number of views in the dataset."""
        return len(self.views)
def tocuda(ans):
    """Move *ans* to the GPU via ``.cuda()``, passing ``None`` through unchanged."""
    return None if ans is None else ans.cuda()
def hugsim_collate(data):
    """Collate for batch size 1: unwrap and return the single sample as-is."""
    assert len(data) == 1
    (sample,) = data
    return sample