# ScaleLSD / scalelsd / ssl / datasets / dataset_util.py
# Nan Xue
# update
# 4c954ae
# raw
# history blame
# 3.25 kB
"""
Interface for initializing the different datasets.
"""
from .synthetic_dataset import SyntheticShapes,synthetic_collate_fn
from .wireframe_dataset import WireframeDataset,wireframe_collate_fn
from .yorkurban_dataset import YorkUrbanDataset,yorkurban_collate_fn
from .images_dataset import ImageCollections, images_collate_fn
# from .holicity_dataset import HolicityDataset
# from .merge_dataset import MergeDataset
import torch.utils.data.dataloader as torch_loader
# The official YorkUrbanDB loader is optional: when the module (or something
# it imports) is unavailable, the name YorkUrban simply stays undefined.
try:
    from .official_yorkurban_dataset import YorkUrban
except ImportError:
    # Only import failures are expected here. Any other exception points to
    # a real problem in the module and should surface instead of being
    # silently swallowed.
    pass
from .nyu_dataset import NYU
from .rdnim_dataset import RDNIM
from .hpatches_dataset import HPatches
def get_dataset(mode="train", dataset_cfg=None, homoadp=False, **kwargs):
    """Initialize a dataset and its collate function from a configuration.

    Args:
        mode: Forwarded to the dataset constructor (every dataset except
            "rdnim", which only receives the config).
        dataset_cfg: Configuration whose "dataset_name" entry selects the
            dataset. The whole config is forwarded to the constructor.
        homoadp: Forwarded only to the "general" dataset.
        **kwargs: Extra keyword arguments, forwarded only to the "general"
            dataset.

    Returns:
        A tuple ``(dataset, collate_fn)``.

    Raises:
        ValueError: If ``dataset_cfg`` is None or names an unsupported
            dataset.
        ImportError: If the selected dataset lives in an optional module
            that could not be imported ("holicity", "merge",
            "official_yorkurban").
    """
    # Check dataset config is given
    if dataset_cfg is None:
        raise ValueError("[Error] The dataset config is required!")

    # Look the name up once. A config without this entry still propagates
    # the lookup error (KeyError for a plain dict).
    dataset_name = dataset_cfg["dataset_name"]

    # Synthetic dataset
    if dataset_name == "synthetic_shape":
        dataset = SyntheticShapes(mode, dataset_cfg)
        collate_fn = synthetic_collate_fn

    # Wireframe dataset
    elif dataset_name == "wireframe":
        dataset = WireframeDataset(mode, dataset_cfg)
        collate_fn = wireframe_collate_fn

    # YorkUrban dataset
    elif dataset_name == "yorkurban":
        dataset = YorkUrbanDataset(mode, dataset_cfg)
        collate_fn = yorkurban_collate_fn

    # Holicity dataset
    elif dataset_name == "holicity":
        # HolicityDataset is not imported at module level (that import is
        # commented out), so it is resolved lazily here.
        # NOTE(review): the module location follows the commented-out
        # import in this file; confirm the module ships with this package.
        try:
            from .holicity_dataset import HolicityDataset, holicity_collate_fn
        except ImportError as err:
            raise ImportError(
                "[Error] The dataset 'holicity' requires the optional module "
                "'.holicity_dataset', which could not be imported."
            ) from err
        dataset = HolicityDataset(mode, dataset_cfg)
        collate_fn = holicity_collate_fn

    # Dataset merging several datasets in one
    elif dataset_name == "merge":
        # Same lazy resolution as for "holicity"; the merged dataset reuses
        # the Holicity collate function.
        # NOTE(review): module locations follow the commented-out imports
        # in this file; confirm they ship with this package.
        try:
            from .merge_dataset import MergeDataset
            from .holicity_dataset import holicity_collate_fn
        except ImportError as err:
            raise ImportError(
                "[Error] The dataset 'merge' requires the optional modules "
                "'.merge_dataset' and '.holicity_dataset', which could not "
                "be imported."
            ) from err
        dataset = MergeDataset(mode, dataset_cfg)
        collate_fn = holicity_collate_fn

    # Generic image collection
    elif dataset_name == "general":
        dataset = ImageCollections(mode, dataset_cfg, homoadp=homoadp, **kwargs)
        collate_fn = images_collate_fn

    ## for the official YorkUrbanDB
    elif dataset_name == "official_yorkurban":
        # The module-level import of YorkUrban is best-effort, so the name
        # may be undefined; report that as a clear import problem.
        try:
            official_yorkurban_cls = YorkUrban
        except NameError as err:
            raise ImportError(
                "[Error] The dataset 'official_yorkurban' requires the "
                "optional module '.official_yorkurban_dataset', which could "
                "not be imported."
            ) from err
        dataset = official_yorkurban_cls(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate

    ## for the NYU_depth_v2
    elif dataset_name == "nyu":
        dataset = NYU(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate

    elif dataset_name == "rdnim":
        # RDNIM is constructed from the config alone; mode is not passed.
        dataset = RDNIM(dataset_cfg)
        collate_fn = torch_loader.default_collate

    elif dataset_name == "hpatches":
        dataset = HPatches(mode, dataset_cfg)
        collate_fn = torch_loader.default_collate

    else:
        raise ValueError(
            "[Error] The dataset '%s' is not supported" % dataset_name)

    return dataset, collate_fn