Spaces:
Running
on
Zero
Running
on
Zero
""" | |
The interface of initializing different datasets. | |
""" | |
from .synthetic_dataset import SyntheticShapes,synthetic_collate_fn | |
from .wireframe_dataset import WireframeDataset,wireframe_collate_fn | |
from .yorkurban_dataset import YorkUrbanDataset,yorkurban_collate_fn | |
from .images_dataset import ImageCollections, images_collate_fn | |
# from .holicity_dataset import HolicityDataset | |
# from .merge_dataset import MergeDataset | |
import torch.utils.data.dataloader as torch_loader | |
try: | |
from .official_yorkurban_dataset import YorkUrban | |
except: | |
pass | |
from .nyu_dataset import NYU | |
from .rdnim_dataset import RDNIM | |
from .hpatches_dataset import HPatches | |
def get_dataset(mode="train", dataset_cfg=None, homoadp=False, **kwargs):
    """Initialize a dataset and its collate function from a configuration.

    Args:
        mode: Mode/split value forwarded to the dataset constructor. It is not
            forwarded for "rdnim", whose constructor receives only the config.
        dataset_cfg: Dataset configuration mapping. ``dataset_cfg["dataset_name"]``
            selects which dataset class is built; the whole mapping is passed
            to the dataset constructor.
        homoadp: Forwarded to ``ImageCollections`` for the "general" dataset
            only.
        **kwargs: Extra keyword arguments forwarded to ``ImageCollections``
            for the "general" dataset only.

    Returns:
        A ``(dataset, collate_fn)`` tuple.

    Raises:
        ValueError: If ``dataset_cfg`` is None or the dataset name is not
            supported.
        KeyError: If ``dataset_cfg`` has no "dataset_name" entry.
        ImportError: If the module backing an optional dataset ("holicity",
            "merge", "official_yorkurban") cannot be imported.
    """
    # Check dataset config is given
    if dataset_cfg is None:
        raise ValueError("[Error] The dataset config is required!")
    dataset_name = dataset_cfg["dataset_name"]

    if dataset_name == "synthetic_shape":
        # Synthetic dataset
        dataset = SyntheticShapes(mode, dataset_cfg)
        collate_fn = synthetic_collate_fn
    elif dataset_name == "wireframe":
        # Wireframe dataset
        dataset = WireframeDataset(mode, dataset_cfg)
        collate_fn = wireframe_collate_fn
    elif dataset_name == "yorkurban":
        dataset = YorkUrbanDataset(mode, dataset_cfg)
        collate_fn = yorkurban_collate_fn
    elif dataset_name == "holicity":
        # Holicity dataset. Imported lazily (and relative to this package, like
        # the other dataset modules) because HolicityDataset is not imported at
        # module level; referencing it directly would raise NameError.
        from .holicity_dataset import HolicityDataset, holicity_collate_fn
        dataset = HolicityDataset(mode, dataset_cfg)
        collate_fn = holicity_collate_fn
    elif dataset_name == "merge":
        # Dataset merging several datasets in one; it reuses the Holicity
        # collate function. Imported lazily for the same reason as above.
        from .merge_dataset import MergeDataset
        from .holicity_dataset import holicity_collate_fn
        dataset = MergeDataset(mode, dataset_cfg)
        collate_fn = holicity_collate_fn
    elif dataset_name == "general":
        dataset = ImageCollections(mode, dataset_cfg, homoadp=homoadp, **kwargs)
        collate_fn = images_collate_fn
    elif dataset_name == "official_yorkurban":
        # Official YorkUrbanDB. Imported here so that an unavailable module
        # raises a clear ImportError instead of a NameError on YorkUrban.
        from .official_yorkurban_dataset import YorkUrban
        dataset = YorkUrban(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate
    elif dataset_name == "nyu":
        # NYU_depth_v2
        dataset = NYU(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate
    elif dataset_name == "rdnim":
        # NOTE(review): `mode` is not passed here; confirm RDNIM takes only
        # the config.
        dataset = RDNIM(dataset_cfg)
        collate_fn = torch_loader.default_collate
    elif dataset_name == "hpatches":
        dataset = HPatches(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate
    else:
        raise ValueError(
            "[Error] The dataset '%s' is not supported" % dataset_name)
    return dataset, collate_fn