# hyzhou404: "private scenes" (commit 7f3c2df)
from torch.utils.data import Dataset
class HUGSIM_dataset(Dataset):
    """Map-style dataset over a sequence of camera views.

    Each item pairs a view with the index of an earlier view ``gap`` steps
    back, where ``gap`` is chosen from ``data_type``.

    Args:
        views: Indexable sequence of view objects. Each view must expose the
            attributes ``original_image``, ``semantic2d``, ``optical_gt``,
            ``depth`` and ``mask``; all but ``original_image`` may be ``None``.
        data_type: Dataset name used to pick the frame gap:
            ``'kitti360'`` -> 4, ``'waymo'`` -> 3, ``'kitti'`` -> 2,
            anything else -> 6.
    """

    # Frame gap per dataset name; unknown names fall back to _DEFAULT_GAP.
    _GAPS = {'kitti360': 4, 'waymo': 3, 'kitti': 2}
    _DEFAULT_GAP = 6

    def __init__(self, views, data_type):
        super().__init__()
        self.views = views
        self.data_type = data_type
        self.gap = self._GAPS.get(data_type, self._DEFAULT_GAP)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            Tuple ``(index, prev_index, gt_image, gt_semantic, gt_optical,
            gt_depth, mask)``. ``prev_index`` is ``index - self.gap``, or
            ``-1`` when that difference would be negative. The last four
            entries are read straight from the view and may be ``None``.
        """
        prev_index = index - self.gap
        if prev_index < 0:
            prev_index = -1

        viewpoint_cam = self.views[index]
        # The optional fields are passed through unchanged; a None attribute
        # simply yields None in the returned tuple.
        return (
            index,
            prev_index,
            viewpoint_cam.original_image,
            viewpoint_cam.semantic2d,
            viewpoint_cam.optical_gt,
            viewpoint_cam.depth,
            viewpoint_cam.mask,
        )

    def __len__(self):
        """Return the number of views."""
        return len(self.views)
def tocuda(ans):
    """Call ``ans.cuda()`` unless ``ans`` is ``None``.

    Args:
        ans: Object with a ``.cuda()`` method (presumably a torch tensor),
            or ``None``.

    Returns:
        ``None`` when ``ans`` is ``None``; otherwise the result of
        ``ans.cuda()``.
    """
    return None if ans is None else ans.cuda()
def hugsim_collate(data):
    """Unwrap a batch that holds exactly one sample.

    Intended as a DataLoader ``collate_fn`` (presumably used with
    ``batch_size=1``): the single sample is returned as-is.

    Args:
        data: Sequence of samples; must contain exactly one.

    Returns:
        The single sample ``data[0]``.

    Raises:
        ValueError: If ``data`` does not contain exactly one sample.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop every sample after the first.
    if len(data) != 1:
        raise ValueError(
            f"hugsim_collate expects a batch of exactly 1 sample, got {len(data)}"
        )
    return data[0]