|
from dataclasses import fields |
|
from typing import Callable |
|
from torch.utils.data import Dataset, ConcatDataset |
|
import bisect |
|
|
|
from ..misc.step_tracker import StepTracker |
|
from .types import Stage |
|
from .view_sampler import get_view_sampler |
|
from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfgWrapper |
|
from .dataset_scannetpp import DatasetScannetpp, DatasetScannetppCfgWrapper |
|
from .dataset_co3d import DatasetCo3d, DatasetCo3dCfgWrapper |
|
|
|
# Registry mapping a dataset config's ``name`` to the dataset class that
# implements it. The values are classes, not instances: get_dataset looks a
# class up by ``cfg.name`` and calls it as ``cls(cfg, stage, view_sampler)``.
DATASETS: dict[str, type[Dataset]] = {
    "co3d": DatasetCo3d,
    "scannetpp": DatasetScannetpp,
    "dl3dv": DatasetDL3DV,
}
|
|
|
# Union of the per-dataset config wrapper dataclasses that get_dataset accepts.
DatasetCfgWrapper = (
    DatasetDL3DVCfgWrapper
    | DatasetScannetppCfgWrapper
    | DatasetCo3dCfgWrapper
)
|
|
|
class TestDatasetWarpper(Dataset):
    """Adapt a tuple-indexed dataset so it can be indexed with a plain index.

    The wrapped dataset is indexed with a 3-tuple. This wrapper builds that
    tuple from the requested index plus two values read off the wrapped
    dataset itself, so callers only need to supply ``idx``.
    """

    def __init__(self, dataset: Dataset):
        """Store the dataset to wrap.

        Args:
            dataset: Dataset exposing ``view_sampler.num_context_views`` and
                ``cfg.input_image_shape`` and accepting 3-tuple keys.
        """
        self.dataset = dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``self.dataset[(idx, num_context_views, shape[1] // 14)]``.

        Args:
            idx: Index forwarded as the first element of the key.
        """
        wrapped = self.dataset
        context_view_count = wrapped.view_sampler.num_context_views
        # Second entry of the configured input image shape, floor-divided by 14.
        # NOTE(review): 14 is presumably a patch size — confirm with the model.
        shape_quotient = wrapped.cfg.input_image_shape[1] // 14
        return wrapped[(idx, context_view_count, shape_quotient)]
|
|
|
|
|
|
|
class CustomConcatDataset(ConcatDataset):
    """ConcatDataset variant whose items are addressed by 3-tuple keys.

    The first element of the key is the index into the concatenation; the
    second and third elements are forwarded unchanged to the sub-dataset that
    owns that index.
    """

    def __getitem__(self, idx_tuple):
        """Fetch an item from the sub-dataset that owns ``idx_tuple[0]``.

        Args:
            idx_tuple: A 3-element key ``(index, second, third)``, or a list
                whose first element is such a key.

        Returns:
            Whatever the owning sub-dataset returns for
            ``(local_index, second, third)``.

        Raises:
            ValueError: If the index is negative and its absolute value
                exceeds the total length.
        """
        # A list is treated as a wrapper around the key; only its first
        # element is used.
        key = idx_tuple[0] if isinstance(idx_tuple, list) else idx_tuple

        position = key[0]
        if position < 0:
            total = len(self)
            if -position > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            # Map a negative index onto its non-negative equivalent.
            position += total

        # cumulative_sizes holds running totals of sub-dataset lengths, so the
        # insertion point identifies the sub-dataset containing `position`.
        owner = bisect.bisect_right(self.cumulative_sizes, position)
        preceding = self.cumulative_sizes[owner - 1] if owner != 0 else 0
        return self.datasets[owner][(position - preceding, key[1], key[2])]
|
|
|
|
|
def _build_dataset(
    cfg: DatasetCfgWrapper,
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> Dataset:
    """Instantiate and shim the dataset described by one config wrapper."""
    # The wrapper is a dataclass expected to have exactly one field, which
    # holds the actual dataset config; the tuple unpacking fails otherwise.
    (field,) = fields(type(cfg))
    dataset_cfg = getattr(cfg, field.name)

    view_sampler = get_view_sampler(
        dataset_cfg.view_sampler,
        stage,
        dataset_cfg.overfit_to_scene is not None,
        dataset_cfg.cameras_are_circular,
        step_tracker,
    )
    dataset = DATASETS[dataset_cfg.name](dataset_cfg, stage, view_sampler)
    return dataset_shim(dataset, stage)


def get_dataset(
    cfgs: list[DatasetCfgWrapper],
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> tuple[CustomConcatDataset, list[Dataset]] | TestDatasetWarpper:
    """Build the dataset(s) for a stage from their config wrappers.

    Args:
        cfgs: Config wrappers, one per dataset. Each wrapper is a dataclass
            with a single field holding the dataset config.
        stage: Stage to build for. ``"test"`` and ``"val"`` are special-cased;
            every other value takes the multi-dataset path.
        step_tracker: Forwarded to ``get_view_sampler``; may be ``None``.
        dataset_shim: Callable applied to every constructed dataset together
            with ``stage``; its return value replaces the dataset.

    Returns:
        For ``"test"``: the single dataset wrapped in ``TestDatasetWarpper``.
        Otherwise: a ``(CustomConcatDataset, datasets)`` pair, where
        ``datasets`` is the list of individual datasets in ``cfgs`` order.
        For ``"val"`` only ``cfgs[0]`` is used.

    Raises:
        AssertionError: If ``stage`` is ``"test"`` and ``cfgs`` does not
            contain exactly one entry.
        IndexError: If ``stage`` is ``"val"`` and ``cfgs`` is empty.
    """
    if stage == "test":
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert len(cfgs) == 1, "the test stage expects exactly one dataset config"
        return TestDatasetWarpper(
            _build_dataset(cfgs[0], stage, step_tracker, dataset_shim)
        )

    # Every stage other than "test" lands here. The previous trailing `else`
    # could never execute (its two preceding conditions were exhaustive) and
    # only constructed a NotImplementedError without raising it, so it has
    # been removed without changing reachable behavior.
    if stage == "val":
        # Validation only uses the first configured dataset.
        cfgs = [cfgs[0]]

    datasets = [
        _build_dataset(cfg, stage, step_tracker, dataset_shim) for cfg in cfgs
    ]
    return CustomConcatDataset(datasets), datasets
|
|