Spaces:
Running
on
Zero
Running
on
Zero
| import trimesh | |
| import numpy as np | |
| from .data_utils import discretize, undiscretize | |
def patchified_mesh(mesh: "trimesh.Trimesh", special_token=-2, fix_orient=True):
    '''
    Serialize a triangle mesh into a sequence of vertex-fan patches.

    The first face that has not been emitted yet is picked, and the corner of
    that face with the largest remaining degree becomes the patch center. The
    not-yet-emitted faces around that center are then walked as a chain of
    boundary vertices, so that the center together with every two consecutive
    chain vertices forms one triangle. Each patch is written as the rows
    ``center, v_0, ..., v_k`` followed by one ``[special_token] * 3`` row.

    Args:
        mesh: object exposing ``faces``, ``vertices``, ``vertex_faces`` (rows
            padded with ``-1``) and ``vertex_degree``.
            NOTE(review): the degree bookkeeping below assumes
            ``vertex_degree`` counts the faces each vertex belongs to --
            confirm against the trimesh version in use.
        special_token: value filling the terminator row of every patch.
        fix_orient: when true, a chain is reversed whenever the triangle
            ``(center, v_0, v_1)`` is wound opposite to the face stored in
            the mesh, so the emitted triangles keep the stored winding.

    Returns:
        ``np.ndarray`` built from all patch rows; one 3-column row per emitted
        vertex or terminator (an empty array when the mesh has no faces).

    Raises:
        AssertionError: if the remaining degrees do not sum to zero once all
            faces have been emitted.
    '''
    faces = mesh.faces
    vertices = mesh.vertices
    vertex_faces = mesh.vertex_faces
    n_faces = len(faces)

    sequence = []
    unvisited = np.full(n_faces, True)
    degrees = mesh.vertex_degree.copy()

    # with fix_orient=True, the normal would be correct.
    # but this may increase the difficulty for learning.
    if fix_orient:
        # Winding lookup: the three cyclic rotations of a stored face map to
        # True, the three reversed rotations map to False.
        face_orient = {}
        for v0, v1, v2 in faces.tolist():
            face_orient[(v0, v1, v2)] = True
            face_orient[(v1, v2, v0)] = True
            face_orient[(v2, v0, v1)] = True
            face_orient[(v2, v1, v0)] = False
            face_orient[(v1, v0, v2)] = False
            face_orient[(v0, v2, v1)] = False

    # Faces are only ever switched from unvisited to visited, so the index of
    # the first unvisited face never moves backwards; a forward-only cursor
    # finds it without re-scanning (or copying) the face array per patch.
    cursor = 0
    while True:
        while cursor < n_faces and not unvisited[cursor]:
            cursor += 1
        if cursor == n_faces:
            break

        # select the patch center
        cur_face = faces[cursor]
        max_deg_vertex = cur_face[np.argmax(degrees[cur_face])]

        # find all connected faces, each stored as [u, v, face_idx] where
        # u < v are the two corners other than the center
        selected_faces = []
        for face_idx in vertex_faces[max_deg_vertex]:
            if face_idx != -1 and unvisited[face_idx]:
                u, v = sorted(
                    vertex for vertex in faces[face_idx] if vertex != max_deg_vertex
                )
                selected_faces.append([u, v, face_idx])
        selected_faces.sort()

        # select the start vertex, select it if it only appears once (the
        # start or end of an open fan), else select the lowest index
        cnt = {}
        for u, v, _ in selected_faces:
            cnt[u] = cnt.get(u, 0) + 1
            cnt[v] = cnt.get(v, 0) + 1
        starts = [vertex for vertex, num in cnt.items() if num == 1]
        start_idx = min(starts) if starts else selected_faces[0][0]

        # walk the fan: from the last chain vertex, cross the first unused
        # face that contains it and append that face's other corner
        face_patch = set()
        res = [start_idx]
        while len(res) <= len(selected_faces):
            vertex = res[-1]
            for u_i, v_i, face_idx_i in selected_faces:
                if face_idx_i not in face_patch and vertex in (u_i, v_i):
                    res.append(v_i if vertex == u_i else u_i)
                    face_patch.add(face_idx_i)
                    break

            # the chain did not advance: no unused face touches its end
            if res[-1] == vertex:
                break

        if fix_orient and len(res) >= 2 and not face_orient[(max_deg_vertex, res[0], res[1])]:
            res = res[::-1]

        # reduce the degree of related vertices and mark the visited faces;
        # the chain consumed len(res) - 1 of the selected faces
        degrees[max_deg_vertex] = len(selected_faces) - len(res) + 1
        last_pos = len(res) - 1
        for pos_idx, vertex in enumerate(res):
            # chain ends lie on one emitted face, inner vertices on two
            if pos_idx in (0, last_pos):
                degrees[vertex] -= 1
            else:
                degrees[vertex] -= 2
        for face_idx in face_patch:
            unvisited[face_idx] = False

        sequence.extend(
            [vertices[max_deg_vertex]] +
            [vertices[vertex_idx] for vertex_idx in res] +
            [[special_token] * 3]
        )

    assert sum(degrees) == 0, 'All degrees should be zero'

    return np.array(sequence)
def get_block_representation(
        sequence,
        block_size=8,
        offset_size=16,
        block_compressed=True,
        special_token=-2,
        use_special_block=True
    ):
    '''
    convert coordinates from Cartesian system to block indexes.

    Every row of ``sequence`` that contains no ``special_token`` entry is
    treated as an (x, y, z) coordinate. It is passed through ``discretize``
    and split into a block id (``coords // offset_size``, flattened with base
    ``block_size``) and an offset id (``coords % offset_size``, flattened with
    base ``offset_size`` and shifted by ``block_size**3`` so it cannot collide
    with a block id). Rows made of ``special_token`` separate patches.

    Args:
        sequence: array-like of 3-column rows; coordinate rows and
            ``special_token`` separator rows. It is not modified.
        block_size: base used to flatten the per-axis block index.
        offset_size: per-axis divisor / base used for the offset index.
        block_compressed: when true, a row whose block id equals the current
            block id emits only its offset id; otherwise block id and offset
            id are both emitted.
        special_token: value marking separator rows.
        use_special_block: when true, separator rows emit nothing and the
            block id of the first row after a separator is shifted by
            ``block_size**3 + offset_size**3``; when false, separator rows
            emit ``special_token`` itself.

    Returns:
        1-D ``np.int64`` array of codes; empty when ``sequence`` is empty.
    '''
    # Work on a private copy: the block/offset pairs are written into the
    # first two columns below, which would otherwise overwrite the caller's
    # coordinate array.
    sequence = np.array(sequence)
    if sequence.size == 0:
        return np.zeros(0, dtype=np.int64)

    special_block_base = block_size**3 + offset_size**3

    # prepare coordinates
    sp_mask = np.all(sequence != special_token, axis=1)
    coords = discretize(sequence[sp_mask].reshape(-1, 3))

    # convert [x, y, z] to [block_id, offset_id]
    block_id = coords // offset_size
    block_id = block_id[:, 0] * block_size**2 + block_id[:, 1] * block_size + block_id[:, 2]
    offset_id = coords % offset_size
    offset_id = offset_id[:, 0] * offset_size**2 + offset_id[:, 1] * offset_size + offset_id[:, 2]
    offset_id += block_size**3
    block_coords = np.concatenate([block_id[..., None], offset_id[..., None]], axis=-1).astype(np.int64)
    sequence[sp_mask, :2] = block_coords
    sequence = sequence[:, :2]

    # convert to codes
    # The stream always opens with the (unshifted) block id of the first row.
    # NOTE(review): with block_compressed=False the first row then emits its
    # block id a second time -- confirm whether that is intended.
    cur_block_id = sequence[0, 0]
    codes = [cur_block_id]
    for block, offset in sequence:
        if block == special_token:
            if not use_special_block:
                codes.append(special_token)
            cur_block_id = special_token

        elif block == cur_block_id:
            if block_compressed:
                codes.append(offset)
            else:
                codes.extend([block, offset])

        else:
            if use_special_block and cur_block_id == special_token:
                emitted = block + special_block_base
            else:
                emitted = block
            codes.extend([emitted, offset])
            # NOTE: after a shifted id is emitted, cur_block_id holds the
            # shifted value, so the next row of the same block does not match
            # it and emits its plain block id again. Kept as is: changing it
            # would change the token stream.
            cur_block_id = emitted

    return np.array(codes).astype(np.int64).flatten()
def BPT_serialize(mesh: trimesh.Trimesh):
    '''
    Serialize a mesh with BPT and return its code sequence.

    Runs the two stages in order: ``patchified_mesh`` turns the faces into
    patches separated by the special token ``-2``, then
    ``get_block_representation`` converts those patch rows into block-wise
    indexes using 8 blocks and 16 offsets per axis, block compression and
    special block ids.
    '''
    patch_separator = -2

    # 1. patchify faces into patches
    patches = patchified_mesh(mesh, special_token=patch_separator)

    # 2. convert coordinates to block-wise indexes
    return get_block_representation(
        patches,
        block_size=8,
        offset_size=16,
        block_compressed=True,
        special_token=patch_separator,
        use_special_block=True,
    )
def decode_block(sequence, compressed=True, block_size=8, offset_size=16):
    '''
    Decode the tokens of one patch back to coordinates.

    With ``compressed=True`` the tokens are scanned in order: a token in
    ``[0, block_size**3)`` becomes the current block id, a token in
    ``[block_size**3, block_size**3 + offset_size**3)`` is an offset id that
    is paired with the current block id (initially 0), and any other token is
    reported with a warning and skipped.

    With ``compressed=False`` ``sequence`` is used as given: it is split in
    two halves along its last axis and modified in place by the arithmetic
    below.
    NOTE(review): for a flat code array that split yields halves rather than
    (block, offset) pairs -- confirm the layout callers pass in this mode.

    Each (block id, offset id) pair is unpacked into three per-axis values,
    ``block_axis * offset_size + offset_axis``, and the result is passed
    through ``undiscretize``, whose return value is returned unchanged.
    '''
    num_block_ids = block_size**3

    # decode from compressed representation
    if compressed:
        token_limit = num_block_ids + offset_size**3
        pairs = []
        active_block = 0
        for position, token in enumerate(sequence):
            if num_block_ids <= token < token_limit:
                pairs.append([active_block, token])
            elif 0 <= token < num_block_ids:
                active_block = token
            else:
                print('[Warning] too large offset idx!', position, token)
        sequence = np.array(pairs)

    block_part, offset_part = np.array_split(sequence, 2, axis=-1)

    # from hash representation to xyz: peel off the most significant axis
    # first, keeping the remainder for the next axis
    offset_part -= num_block_ids
    per_axis = []
    for power in (2, 1, 0):
        block_stride = block_size**power
        offset_stride = offset_size**power
        per_axis.append(
            (block_part // block_stride) * offset_size + (offset_part // offset_stride)
        )
        block_part %= block_stride
        offset_part %= offset_stride

    grid_coords = np.concatenate(per_axis, axis=-1)  # (nf 3)

    # back to continuous space
    return undiscretize(grid_coords)
def BPT_deserialize(sequence, block_size=8, offset_size=16, compressed=True, special_token=-2, use_special_block=True):
    '''
    Decode a BPT code sequence back to triangle corner coordinates.

    The codes are cut into patches, each patch is decoded with
    ``decode_block``, and a decoded patch ``center, v_0, ..., v_k`` is
    expanded into the triangles ``(center, v_j, v_j+1)``.

    Patch boundaries:
        * ``use_special_block=False``: a patch ends at every ``special_token``
          and at the last position; the token at the boundary itself is not
          decoded.
        * ``use_special_block=True``: a patch starts at every token in
          ``[base, base + block_size**3)`` with
          ``base = block_size**3 + offset_size**3``; that token has ``base``
          subtracted before decoding. The last position closes the final
          patch and is included in it.

    Args:
        sequence: flat array-like of codes. It is not modified.
        block_size, offset_size, compressed: forwarded to ``decode_block``.
        special_token: patch separator used when ``use_special_block`` is
            false.
        use_special_block: selects the boundary rule described above.

    Returns:
        Array with one 3-column row per triangle corner, three consecutive
        rows per triangle.

    Raises:
        ValueError: from ``np.concatenate`` when no triangle was decoded.
    '''
    # Private copy: the special-block branch rewrites the leading token of a
    # patch through a slice. On the caller's ndarray that slice is a view, so
    # the input would be corrupted and a second call on the same codes would
    # no longer find the patch boundaries.
    sequence = np.array(sequence)

    # decode codes back to coordinates
    special_block_base = block_size**3 + offset_size**3
    special_block_end = special_block_base + block_size**3
    last_idx = len(sequence) - 1
    start_idx = 0
    vertices = []
    for i in range(len(sequence)):
        sub_seq = []
        if not use_special_block and (sequence[i] == special_token or i == last_idx):
            sub_seq = sequence[start_idx:i]
            sub_seq = decode_block(sub_seq, compressed=compressed, block_size=block_size, offset_size=offset_size)
            start_idx = i + 1

        elif use_special_block and \
            (special_block_base <= sequence[i] < special_block_end or i == last_idx):
            # position 0 can only open the first patch, never close one
            if i != 0:
                sub_seq = sequence[start_idx:i] if i != last_idx else sequence[start_idx:i + 1]
                if special_block_base <= sub_seq[0] < special_block_end:
                    sub_seq[0] -= special_block_base
                sub_seq = decode_block(sub_seq, compressed=compressed, block_size=block_size, offset_size=offset_size)
                start_idx = i

        if len(sub_seq):
            # first decoded vertex is the fan center; consecutive pairs of
            # the remaining vertices each close one triangle with it
            center, sub_seq = sub_seq[0], sub_seq[1:]
            for j in range(len(sub_seq) - 1):
                vertices.extend([center.reshape(1, 3), sub_seq[j].reshape(1, 3), sub_seq[j + 1].reshape(1, 3)])

    # (nf, 3)
    return np.concatenate(vertices, axis=0)
if __name__ == '__main__':
    # Demo: round-trip a mesh through BPT serialization and deserialization,
    # exporting the input as gt.obj and the decoded result as
    # reconstructed.obj.
    from data_utils import load_process_mesh, to_mesh
    import torch

    loaded = load_process_mesh('/path/to/your/mesh', quantization_bits=7)
    loaded['faces'] = np.array(loaded['faces'])
    gt_mesh = to_mesh(loaded['vertices'], loaded['faces'], transpose=True)
    gt_mesh.export('gt.obj')

    codes = BPT_serialize(gt_mesh)
    corner_coords = BPT_deserialize(codes)

    # the decoded rows are consecutive triangle corners, so the faces are
    # simply the running index (starting at 1) grouped in threes
    corner_faces = torch.arange(1, len(corner_coords) + 1).view(-1, 3)
    rebuilt = to_mesh(corner_coords, corner_faces, transpose=False, post_process=False)
    rebuilt.export('reconstructed.obj')