from dataclasses import fields
from typing import Callable
from torch.utils.data import Dataset, ConcatDataset
import bisect
from ..misc.step_tracker import StepTracker
from .types import Stage
from .view_sampler import get_view_sampler
from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfgWrapper
from .dataset_scannetpp import DatasetScannetpp, DatasetScannetppCfgWrapper
from .dataset_co3d import DatasetCo3d, DatasetCo3dCfgWrapper
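
# Maps each dataset name given in the config to its dataset class.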
DATASETS: dict[str, type[Dataset]] = {
"co3d": DatasetCo3d,
"scannetpp": DatasetScannetpp,
"dl3dv": DatasetDL3DV,
}
DatasetCfgWrapper = DatasetDL3DVCfgWrapper | DatasetScannetppCfgWrapper | DatasetCo3dCfgWrapper
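
# Test-stage wrapper that converts a plain integer index into the tuple index
# (index, num_context_views, patch_grid_width) expected by the wrapped dataset,
# where the patch-grid width is the configured image width divided by 14
# (presumably the ViT patch size).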
class TestDatasetWarpper(Dataset):
    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        # Placeholder sampling parameters, matching the tuple index the dataset expects.
        return self.dataset[(
            idx,
            self.dataset.view_sampler.num_context_views,
            self.dataset.cfg.input_image_shape[1] // 14,
        )]

    def __len__(self):
        return len(self.dataset)
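
# ConcatDataset variant that accepts tuple indices of the form
# (global_index, num_context_views, patch_grid_width): the global index is remapped
# to the matching sub-dataset and the remaining fields are forwarded unchanged.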
class CustomConcatDataset(ConcatDataset):
    def __getitem__(self, idx_tuple):
        # The index may arrive wrapped in a single-element list; unwrap it.
        if isinstance(idx_tuple, list):
            idx_tuple = idx_tuple[0]
        idx = idx_tuple[0]
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Find the sub-dataset containing `idx` and convert it to a local index.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Pass the remaining index fields through to the selected sub-dataset.
        return self.datasets[dataset_idx][(sample_idx, idx_tuple[1], idx_tuple[2])]
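
# Builds the datasets described by `cfgs` for the given stage. For "train" and "val"
# it returns a CustomConcatDataset over all configs together with the list of the
# individual datasets; for "test" it expects exactly one config and returns the
# dataset wrapped in TestDatasetWarpper.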
def get_dataset(
    cfgs: list[DatasetCfgWrapper],
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> tuple[CustomConcatDataset, list[Dataset]] | TestDatasetWarpper:
datasets = []
    if stage != "test":
        # Validation only uses the first dataset config.
        if stage == "val":
            cfgs = [cfgs[0]]
for cfg in cfgs:
(field,) = fields(type(cfg))
cfg = getattr(cfg, field.name)
view_sampler = get_view_sampler(
cfg.view_sampler,
stage,
cfg.overfit_to_scene is not None,
cfg.cameras_are_circular,
step_tracker,
)
dataset = DATASETS[cfg.name](cfg, stage, view_sampler)
dataset = dataset_shim(dataset, stage)
datasets.append(dataset)
return CustomConcatDataset(datasets), datasets
elif stage == "test":
assert len(cfgs) == 1
cfg = cfgs[0]
(field,) = fields(type(cfg))
cfg = getattr(cfg, field.name)
view_sampler = get_view_sampler(
cfg.view_sampler,
stage,
cfg.overfit_to_scene is not None,
cfg.cameras_are_circular,
step_tracker,
)
dataset = DATASETS[cfg.name](cfg, stage, view_sampler)
dataset = dataset_shim(dataset, stage)
return TestDatasetWarpper(dataset)
    else:
        raise NotImplementedError(f"Stage {stage} is not supported")
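
# Hedged usage sketch: the cfg wrappers and the shim are normally supplied by the
# project's config/data pipeline, and `my_shim` below is a hypothetical
# Callable[[Dataset, str], Dataset].
#
#     concat_set, per_dataset = get_dataset(cfgs, "train", step_tracker, my_shim)
#     test_set = get_dataset(cfgs, "test", step_tracker, my_shim)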