|
""" |
|
Default Datasets |
|
|
|
Author: Xiaoyang Wu ([email protected]) |
|
Please cite our work if the code is helpful to you. |
|
""" |
|
|
|
import os |
|
import glob |
|
import numpy as np |
|
import torch |
|
from copy import deepcopy |
|
from torch.utils.data import Dataset |
|
from collections.abc import Sequence |
|
|
|
from pointcept.utils.logger import get_root_logger |
|
from pointcept.utils.cache import shared_dict |
|
|
|
from .builder import DATASETS, build_dataset |
|
from .transform import Compose, TRANSFORMS |
|
|
|
|
|
@DATASETS.register_module()
class DefaultDataset(Dataset):
    """Generic point-cloud dataset backed by per-sample folders of ``.npy`` files.

    Each sample is a directory under ``<data_root>/<split>/`` holding arrays
    named after entries of ``VALID_ASSETS`` (e.g. ``coord.npy``, ``color.npy``).
    In test mode, samples are expanded into augmented, voxelized (and
    optionally cropped) fragments for test-time aggregation.
    """

    # Asset stems (file names without the ".npy" suffix) that get_data() loads;
    # any other file in a sample directory is ignored.
    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "strength",
        "segment",
        "instance",
        "pose",
    ]

    def __init__(
        self,
        split="train",
        data_root="data/dataset",
        transform=None,
        test_mode=False,
        test_cfg=None,
        cache=False,
        ignore_index=-1,
        loop=1,
    ):
        """Initialize the dataset.

        Args:
            split (str | Sequence[str]): subfolder(s) of ``data_root`` to index.
            data_root (str): root directory of the processed dataset.
            transform (list | None): transform configs passed to ``Compose``.
            test_cfg: config providing ``voxelize``, ``crop``,
                ``post_transform`` and ``aug_transform``; required when
                ``test_mode`` is True.
            test_mode (bool): enable the fragment-based test pipeline.
            cache (bool): read samples from the shared-memory cache instead of
                loading ``.npy`` files from disk.
            ignore_index (int): label value ignored by losses/metrics.
            loop (int): virtual repetition factor for one epoch; forced to 1
                in test mode so every sample is visited exactly once.
        """
        super().__init__()
        self.data_root = data_root
        self.split = split
        self.transform = Compose(transform)
        self.cache = cache
        self.ignore_index = ignore_index
        # Looping only makes sense for training-time epoch lengthening.
        self.loop = loop if not test_mode else 1
        self.test_mode = test_mode
        self.test_cfg = test_cfg if test_mode else None

        if test_mode:
            # prepare_test_data() already handles a missing voxelizer
            # (``self.test_voxelize is None``), so guard the build here instead
            # of crashing on ``TRANSFORMS.build(None)``.
            self.test_voxelize = (
                TRANSFORMS.build(self.test_cfg.voxelize)
                if self.test_cfg.voxelize is not None
                else None
            )
            self.test_crop = (
                TRANSFORMS.build(self.test_cfg.crop) if self.test_cfg.crop else None
            )
            self.post_transform = Compose(self.test_cfg.post_transform)
            self.aug_transform = [Compose(aug) for aug in self.test_cfg.aug_transform]

        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in {} set.".format(
                len(self.data_list), self.loop, split
            )
        )

    def get_data_list(self):
        """Collect sample directories for the configured split(s)."""
        if isinstance(self.split, str):
            splits = [self.split]
        elif isinstance(self.split, Sequence):
            splits = self.split
        else:
            raise NotImplementedError
        data_list = []
        for split in splits:
            data_list += glob.glob(os.path.join(self.data_root, split, "*"))
        return data_list

    def get_data(self, idx):
        """Load one sample as a dict of numpy arrays (from cache or disk).

        Coordinates/colors/normals are cast to float32; segment/instance
        labels are flattened to int32 and default to -1 (unlabeled) when the
        corresponding file is absent.
        """
        data_path = self.data_list[idx % len(self.data_list)]
        name = self.get_data_name(idx)
        if self.cache:
            return shared_dict(f"pointcept-{name}")

        data_dict = {}
        for asset in os.listdir(data_path):
            stem, ext = os.path.splitext(asset)
            if ext != ".npy" or stem not in self.VALID_ASSETS:
                continue
            data_dict[stem] = np.load(os.path.join(data_path, asset))
        data_dict["name"] = name

        for key in ("coord", "color", "normal"):
            if key in data_dict:
                data_dict[key] = data_dict[key].astype(np.float32)

        for key in ("segment", "instance"):
            if key in data_dict:
                data_dict[key] = data_dict[key].reshape([-1]).astype(np.int32)
            else:
                # NOTE(review): assumes "coord" is always present when labels
                # are missing; -1 marks every point as unlabeled.
                data_dict[key] = np.full(
                    data_dict["coord"].shape[0], -1, dtype=np.int32
                )
        return data_dict

    def get_data_name(self, idx):
        """Return the sample's directory name (used as its unique id)."""
        return os.path.basename(self.data_list[idx % len(self.data_list)])

    def prepare_train_data(self, idx):
        """Load a sample and apply the training transform pipeline."""
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        return data_dict

    def prepare_test_data(self, idx):
        """Build the test-time dict: ground truth plus augmented fragments."""
        data_dict = self.get_data(idx)
        data_dict = self.transform(data_dict)
        # Ground truth and bookkeeping are split off so they are not copied
        # into every augmented fragment below.
        result_dict = dict(segment=data_dict.pop("segment"), name=data_dict.pop("name"))
        if "origin_segment" in data_dict:
            assert "inverse" in data_dict
            result_dict["origin_segment"] = data_dict.pop("origin_segment")
            result_dict["inverse"] = data_dict.pop("inverse")

        # One independent copy of the sample per test-time augmentation.
        data_dict_list = [aug(deepcopy(data_dict)) for aug in self.aug_transform]

        fragment_list = []
        for data in data_dict_list:
            if self.test_voxelize is not None:
                data_part_list = self.test_voxelize(data)
            else:
                # No voxelization: keep the full cloud with identity indices.
                data["index"] = np.arange(data["coord"].shape[0])
                data_part_list = [data]
            for data_part in data_part_list:
                if self.test_crop is not None:
                    # test_crop may split one part into several crops.
                    data_part = self.test_crop(data_part)
                else:
                    data_part = [data_part]
                fragment_list += data_part

        result_dict["fragment_list"] = [
            self.post_transform(fragment) for fragment in fragment_list
        ]
        return result_dict

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_data(idx)
        return self.prepare_train_data(idx)

    def __len__(self):
        # Virtual length: the physical sample list repeated ``loop`` times.
        return len(self.data_list) * self.loop
|
|
|
|
|
@DATASETS.register_module()
class ConcatDataset(Dataset):
    """Concatenation of several datasets behind a single flat index.

    Each entry of ``data_list`` is an integer pair ``(dataset_idx, sample_idx)``
    mapping a flat index onto one of the wrapped datasets.
    """

    def __init__(self, datasets, loop=1):
        """Build and concatenate the given dataset configs.

        Args:
            datasets (Sequence): dataset configs, each passed to
                ``build_dataset``.
            loop (int): virtual repetition factor for the concatenated set.
        """
        super().__init__()
        self.datasets = [build_dataset(dataset) for dataset in datasets]
        self.loop = loop
        self.data_list = self.get_data_list()
        logger = get_root_logger()
        logger.info(
            "Totally {} x {} samples in the concat set.".format(
                len(self.data_list), self.loop
            )
        )

    def get_data_list(self):
        """Return flat ``(dataset_idx, sample_idx)`` pairs as plain ints.

        Plain Python ints are required: the previous
        ``np.ones(n) * i`` construction produced ``np.float64`` dataset
        indices, and indexing the Python list ``self.datasets`` with a float
        raises ``TypeError``.
        """
        data_list = []
        for dataset_idx, dataset in enumerate(self.datasets):
            data_list.extend(
                (dataset_idx, sample_idx) for sample_idx in range(len(dataset))
            )
        return data_list

    def get_data(self, idx):
        """Fetch the sample behind flat index ``idx`` (wrapping via modulo)."""
        dataset_idx, data_idx = self.data_list[idx % len(self.data_list)]
        return self.datasets[dataset_idx][data_idx]

    def get_data_name(self, idx):
        """Delegate name lookup to the owning sub-dataset."""
        dataset_idx, data_idx = self.data_list[idx % len(self.data_list)]
        return self.datasets[dataset_idx].get_data_name(data_idx)

    def __getitem__(self, idx):
        return self.get_data(idx)

    def __len__(self):
        # Virtual length: the physical sample list repeated ``loop`` times.
        return len(self.data_list) * self.loop
|
|