jev-aleks's picture
scenedino init
9e15541
import os
from pathlib import Path
from scenedino.datasets.kitti_360_v2 import KITTI360DatasetV2
from datasets.cityscapes.cityscapes_dataset import CityscapesSeg
from datasets.bdd.bdd_dataset import BDDSeg
from .base_dataset import BaseDataset
from .kitti_360 import KITTI360Dataset
from .old_kitti_360 import OldKITTI360Dataset
from .re10k_dataset import RealEstate10kDataset
import torch
# TODO: make more generic -> no more two function
def make_datasets(config) -> tuple[BaseDataset, BaseDataset]:
dataset_type = config.get("type", "KITTI_360")
match dataset_type:
case "KITTI_360":
if config.get("split_path", None) is None:
train_split_path = None
test_split_path = None
else:
train_split_path = Path(config["split_path"]) / "train_files.txt"
test_split_path = Path(config["split_path"]) / "val_files.txt"
train_dataset = KITTI360Dataset(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=train_split_path,
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=False,
return_segmentation=config.get("data_segmentation", False),
return_occupancy=False,
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offsets=config.get("fisheye_offset", [10]),
stereo_offsets=config.get("stereo_offset", [1]),
# color_aug=config.get("color_aug", False),
is_preprocessed=config.get("is_preprocessed", False),
)
test_dataset = KITTI360Dataset(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=test_split_path,
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=True,
return_segmentation=config.get("data_segmentation", False),
return_occupancy=config.get("occupancy", False),
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offsets=[10],
stereo_offsets=[1],
is_preprocessed=config.get("is_preprocessed", False),
)
return train_dataset, test_dataset
case "old_KITTI_360":
if config.get("split_path", None) is None:
train_split_path = None
test_split_path = None
else:
train_split_path = Path(config["split_path"]) / "train_files.txt"
test_split_path = Path(config["split_path"]) / "test_files.txt"
train_dataset = OldKITTI360Dataset(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=train_split_path,
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=False,
return_segmentation=config.get("data_segmentation", False),
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offset=config.get("fisheye_offset", [10]),
# stereo_offsets=config.get("stereo_offset", [1]),
color_aug=config.get("color_aug", False),
is_preprocessed=config.get("is_preprocessed", False),
)
test_dataset = OldKITTI360Dataset(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=test_split_path,
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=True,
return_segmentation=config.get("data_segmentation", False),
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offset=[10],
# stereo_offsets=[1],
is_preprocessed=config.get("is_preprocessed", False),
)
return train_dataset, test_dataset
case "KITTI_360_v2":
if config.get("split_path", None) is None:
train_split_path = None
test_split_path = None
else:
train_split_path = Path(config["split_path"]) / "train_files.txt"
test_split_path = Path(config["split_path"]) / "val_files.txt"
train_dataset = KITTI360DatasetV2(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=train_split_path,
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=False,
return_segmentation=config.get("data_segmentation", False),
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offset=config.get("fisheye_offset", [10]),
color_aug=config.get("color_aug", False),
is_preprocessed=config.get("is_preprocessed", False),
)
test_dataset = KITTI360DatasetV2(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=test_split_path,
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=True,
return_segmentation=config.get("data_segmentation", False),
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offset=[10],
is_preprocessed=config.get("is_preprocessed", False),
)
return train_dataset, test_dataset
case "Cityscapes_seg":
train_dataset = CityscapesSeg(
root=config["data_path"],
image_set="train",
)
test_dataset = CityscapesSeg(
root=config["data_path"],
image_set="val",
)
return train_dataset, test_dataset
case "RealEstate10K":
if config.get("split_path", None) is None:
train_split_path = None
test_split_path = None
else:
train_split_path = Path(config["split_path"]) / "train_files.txt"
test_split_path = Path(config["split_path"]) / "val_files.txt"
train_dataset = RealEstate10kDataset(
data_path=config["data_path"],
split_path=train_split_path,
image_size=config["image_size"],
)
test_dataset = RealEstate10kDataset(
data_path=config["data_path"],
split_path=test_split_path,
image_size=config["image_size"],
)
return train_dataset, test_dataset
case "BDD_seg":
train_dataset = BDDSeg(
root=config["data_path"],
image_set="train",
)
test_dataset = BDDSeg(
root=config["data_path"],
image_set="val",
)
return train_dataset, test_dataset
case _:
raise NotImplementedError(f"Unsupported dataset type: {type}")
def make_test_dataset(config):
dataset_type = config.get("type", "KITTI_Raw")
match dataset_type:
case "KITTI_360":
test_dataset = OldKITTI360Dataset(
data_path=config["data_path"],
pose_path=config["pose_path"],
split_path=os.path.join(
config.get("split_path", None), "test_files.txt"
),
target_image_size=tuple(config.get("image_size", (192, 640))),
frame_count=config.get("data_fc", 1),
return_stereo=config.get("data_stereo", False),
return_fisheye=config.get("data_fisheye", False),
return_3d_bboxes=config.get("data_3d_bboxes", False),
return_segmentation=config.get("data_segmentation", False),
keyframe_offset=0,
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offset=config.get("fisheye_offset", 1),
dilation=config.get("dilation", 1),
is_preprocessed=config.get("is_preprocessed", False),
)
return test_dataset
case "old_KITTI_360":
test_dataset = OldKITTI360Dataset(
data_path=Path(config["data_path"]),
pose_path=Path(config["pose_path"]),
split_path=os.path.join(
config.get("split_path", None), "test_files.txt"
),
target_image_size=tuple(config.get("image_size", (192, 640))),
return_stereo=config.get("data_stereo", True),
return_fisheye=config.get("data_fisheye", True),
frame_count=config.get("data_fc", 3),
return_depth=True,
return_segmentation=config.get("data_segmentation", False),
keyframe_offset=config.get("keyframe_offset", 0),
dilation=config.get("dilation", 1),
fisheye_rotation=config.get("fisheye_rotation", 0),
fisheye_offset=config.get("fisheye_offset", 1),
# stereo_offsets=[1],
is_preprocessed=config.get("is_preprocessed", False),
)
return test_dataset
case "Cityscapes_seg":
test_dataset = CityscapesSeg(
root=config["data_path"],
image_set="val",
)
return test_dataset
case "RealEstate10K":
test_dataset = RealEstate10kDataset(
data_path=config["data_path"],
image_size=config["image_size"],
)
return test_dataset
case "BDD_seg":
test_dataset = BDDSeg(
root=config["data_path"],
image_set="val",
)
return test_dataset
case _:
raise NotImplementedError(f"Unsupported dataset type: {dataset_type}")