|
import copy |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
from scipy.linalg import polar |
|
from scipy.spatial.transform import Rotation |
|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
from .utils import exists |
|
from .utils.logger import print_log |
|
|
|
|
|
def create_dataset(cfg_dataset):
    """Instantiate the dataset class named by ``cfg_dataset['name']``.

    Args:
        cfg_dataset: mapping with a ``'name'`` key that selects the dataset
            class via ``get_dataset``; every remaining key is forwarded as a
            keyword argument to that class's constructor. The mapping itself
            is not modified.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if ``'name'`` is missing or is not a registered dataset.
    """
    # Shallow-copy before popping: the original aliased the caller's config,
    # so 'name' was removed from it and a second call with the same config
    # raised KeyError.
    kwargs = dict(cfg_dataset)
    name = kwargs.pop('name')
    dataset = get_dataset(name)(**kwargs)
    print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}")
    return dataset
|
|
|
def get_dataset(name):
    """Look up the dataset class registered under ``name``.

    Raises:
        KeyError: if ``name`` is not a registered dataset name.
    """
    registry = {'base': PrimitiveDataset}
    dataset_cls = registry[name]
    return dataset_cls
|
|
|
|
|
# Integer code for each supported primitive shape name. The names are the
# same ones used as values in PrimitiveDataset.typeid_map.
SHAPE_CODE = dict(
    CubeBevel=0,
    SphereSharp=1,
    CylinderSharp=2,
)
|
|
|
|
|
class PrimitiveDataset(Dataset):
    """Dataset over a directory of point-cloud files.

    Each entry of ``pc_dir`` is one sample. ``__getitem__`` reads it with
    ``o3d.io.read_point_cloud`` and returns ``{'pc': tensor}``, where the
    tensor layout is chosen by ``pc_format``.

    Args:
        pc_dir: directory of point-cloud files; every name returned by
            ``os.listdir`` is treated as a sample.
        bs_dir: directory that must contain ``basic_shapes.json``.
        max_length, range_scale, range_rotation, range_translation,
        rotation_type: stored on the instance but not used by the methods
            here. NOTE(review): presumably consumed by other code — confirm.
        pc_format: ``'pc'`` (points + colors), ``'pn'`` (points + normals)
            or ``'pcn'`` (points + colors + normals). Validated lazily in
            ``__getitem__``.
    """

    def __init__(self,
                 pc_dir,
                 bs_dir,
                 max_length=144,
                 range_scale=[0, 1],
                 range_rotation=[-180, 180],
                 range_translation=[-1, 1],
                 rotation_type='euler',
                 pc_format='pc',
                 ):
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the index -> file mapping is reproducible across machines.
        self.data_filename = sorted(os.listdir(pc_dir))

        self.pc_dir = pc_dir
        self.max_length = max_length
        # The list defaults are single objects shared by every call; copy them
        # so mutating one instance's range cannot leak into other instances.
        self.range_scale = copy.copy(range_scale)
        self.range_rotation = copy.copy(range_rotation)
        self.range_translation = copy.copy(range_translation)
        self.rotation_type = rotation_type
        self.pc_format = pc_format

        with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f:
            # Previously loaded and immediately discarded; keep it accessible.
            self.basic_shapes = json.load(f)

        # Numeric asset type id -> shape name (same names as SHAPE_CODE keys).
        self.typeid_map = {
            1101002001034001: 'CubeBevel',
            1101002001034010: 'SphereSharp',
            1101002001034002: 'CylinderSharp',
        }

    def __len__(self):
        """Return the number of files found in ``pc_dir``."""
        return len(self.data_filename)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``{'pc': tensor}``.

        Raises:
            ValueError: if ``self.pc_format`` is not 'pc', 'pn' or 'pcn'.
        """
        pc_file = os.path.join(self.pc_dir, self.data_filename[idx])
        # NOTE(review): ``o3d`` (presumably ``import open3d as o3d``) is not
        # imported in the visible source; this raises NameError unless it is
        # brought into scope elsewhere.
        pc = o3d.io.read_point_cloud(pc_file)

        # Assumes open3d's per-point arrays, presumably (N, 3) each — confirm.
        points = torch.from_numpy(np.asarray(pc.points)).float()
        colors = torch.from_numpy(np.asarray(pc.colors)).float()
        normals = torch.from_numpy(np.asarray(pc.normals)).float()

        if self.pc_format == 'pc':
            # NOTE(review): only this branch is transposed (features-first),
            # unlike 'pn'/'pcn'. Kept as-is because consumers may rely on it.
            features = torch.cat([points, colors], dim=-1).T
        elif self.pc_format == 'pn':
            features = torch.cat([points, normals], dim=-1)
        elif self.pc_format == 'pcn':
            features = torch.cat([points, colors, normals], dim=-1)
        else:
            raise ValueError(f'invalid pc_format: {self.pc_format}')

        return {'pc': features}
|
|
|
|
|
def get_typeid_shapename_mapping(shapenames, config_data):
    """Map each config entry's ``typeId`` to a shape name taken from ``shapenames``.

    For every value ``entry`` of ``config_data``, the FIRST shapename ``s`` whose
    core ``s[3:-4]`` (first 3 and last 4 characters removed) occurs as a
    substring of ``entry['bpPath']`` is selected. Its fourth ``'_'``-separated
    field, ``s.split('_')[3]``, becomes the value for ``entry['typeId']``.
    Entries with no matching shapename are skipped.

    NOTE(review): the slice presumably strips a 3-char prefix (e.g. 'SM_')
    and a 4-char extension (e.g. '.fbx') — confirm against real filenames.
    """
    mapping = {}
    for entry in config_data.values():
        hit = next(
            (candidate for candidate in shapenames
             if candidate[3:-4] in entry['bpPath']),
            None,
        )
        if hit is not None:
            mapping[entry['typeId']] = hit.split('_')[3]
    return mapping
|
|
|
|
|
def check_valid_range(data, value_range):
    """Return True if every element of ``data`` lies in the closed range ``value_range``.

    Args:
        data: numbers, as anything ``np.asarray`` accepts (array, list, scalar).
        value_range: ``(lo, hi)`` pair with ``hi > lo``.

    Returns:
        bool: True when ``lo <= x <= hi`` for all elements; vacuously True for
        empty ``data``.

    Raises:
        AssertionError: if ``hi <= lo``.
    """
    lo, hi = value_range
    assert hi > lo, f'invalid range: hi ({hi}) must be greater than lo ({lo})'
    values = np.asarray(data)
    # Compare the data against hi; the previous ``hi <= hi`` was always True,
    # so values above the upper bound were accepted.
    return bool(np.logical_and(values >= lo, values <= hi).all())
|
|
|
|
|
def quat_to_euler(quat, degree=True):
    """Convert quaternion(s) to intrinsic ``'XYZ'`` Euler angles.

    Args:
        quat: quaternion(s) in SciPy's default scalar-last ``(x, y, z, w)`` order.
        degree: return degrees when True, radians otherwise.
    """
    rotation = Rotation.from_quat(quat)
    angles = rotation.as_euler('XYZ', degrees=degree)
    return angles
|
|
|
|
|
def quat_to_rotvec(quat, degree=True):
    """Convert quaternion(s) to rotation vector(s): unit axis scaled by angle.

    Args:
        quat: quaternion(s) in SciPy's default scalar-last ``(x, y, z, w)`` order.
        degree: express the angle magnitude in degrees when True, radians otherwise.
    """
    rotation = Rotation.from_quat(quat)
    return rotation.as_rotvec(degrees=degree)
|
|
|
|
|
def rotate_axis(euler):
    """Return a 4x4 homogeneous transform containing only a rotation.

    Args:
        euler: three extrinsic ``'xyz'`` Euler angles in radians.
    """
    rotation_block = Rotation.from_euler('xyz', euler).as_matrix()
    homogeneous = np.identity(4)
    homogeneous[:3, :3] = rotation_block
    return homogeneous
|
|
|
|
|
def SRT_quat_to_matrix(scale, quat, translation):
    """Compose scale, rotation and translation into a 4x4 homogeneous matrix.

    The upper-left 3x3 block is ``R * scale``. For a length-3 ``scale`` this
    broadcasts over columns, i.e. ``R @ diag(scale)`` (scale applied before
    rotation); a scalar ``scale`` gives uniform scaling.

    Args:
        scale: scalar or length-3 per-axis scale.
        quat: rotation quaternion in SciPy's scalar-last ``(x, y, z, w)`` order.
        translation: length-3 translation, written into the last column.
    """
    matrix = np.eye(4)
    matrix[:3, :3] = Rotation.from_quat(quat).as_matrix() * scale
    matrix[:3, 3] = translation
    return matrix
|
|
|
|
|
def matrix_to_SRT_quat2(transform_matrix):
    """Split a 4x4 transform into (scale, quaternion, translation).

    The 3x3 block ``A`` is factored by the right polar decomposition
    ``A = U @ P``. ``U`` becomes the rotation quaternion (SciPy scalar-last
    order), and the diagonal of ``P`` is returned as the per-axis scale. Any
    off-diagonal part of ``P`` is dropped, so the result is exact only when
    ``A`` has the form ``R @ diag(s)``.

    Returns:
        tuple: ``(scale, quat, translation)`` as NumPy arrays.
    """
    matrix = np.array(transform_matrix)
    unitary_part, positive_part = polar(matrix[:3, :3])
    return (
        np.diag(positive_part),
        Rotation.from_matrix(unitary_part).as_quat(),
        matrix[:3, 3],
    )
|
|
|
|
|
def apply_transform_to_block(block, trans_aug, atol=1e-1):
    """Pre-multiply a block's scale/rotation/location transform by ``trans_aug``.

    The block's ``data['scale']``, ``data['rotation']`` (quaternion) and
    ``data['location']`` are composed with ``SRT_quat_to_matrix``. The result
    is left-multiplied by ``trans_aug``, then decomposed again with
    ``matrix_to_SRT_quat2``. The decomposition is re-composed and compared
    with the combined matrix to detect transforms that cannot be expressed
    as scale + rotation + translation.

    Args:
        block: dict with a ``'data'`` sub-dict holding ``'scale'``,
            ``'rotation'`` and ``'location'``. It is not modified.
        trans_aug: transform to apply; assumed to be a 4x4 homogeneous matrix
            usable with ``@`` — confirm at call sites.
        atol: absolute tolerance for the reconstruction check
            (``np.allclose``); the default matches the previous hard-coded value.

    Returns:
        tuple: ``(precision_loss, new_block)``. When the reconstruction differs
        beyond ``atol``, this is ``(True, {})``. Otherwise it is
        ``(False, new_block)``, where ``new_block`` is a deep copy of ``block``
        with the three fields replaced by lists.
    """
    data = block['data']
    local = SRT_quat_to_matrix(data['scale'], data['rotation'], data['location'])

    # Left-multiplication: trans_aug acts after the block's own transform.
    combined = trans_aug @ local
    scale, quat, translation = matrix_to_SRT_quat2(combined)

    # matrix_to_SRT_quat2 keeps only the diagonal of the polar factor, so a
    # mismatch here means information (e.g. shear) was lost.
    reconstructed = SRT_quat_to_matrix(scale, quat, translation)
    if not np.allclose(combined, reconstructed, atol=atol):
        return True, {}

    new_block = copy.deepcopy(block)
    new_block['data']['scale'] = scale.tolist()
    new_block['data']['rotation'] = quat.tolist()
    new_block['data']['location'] = translation.tolist()
    return False, new_block
|
|