Spaces:
Paused
Paused
| from typing import Union, Tuple, List | |
| import numpy as np | |
| import torch | |
| from skimage import measure | |
class MeshExtractResult:
    """Triangle mesh returned by a surface extractor, with derived normals.

    Attributes set by the constructor:
        verts: vertex positions, stored as passed in. Indexed by row and
            fed to a 3-component cross product, i.e. used as ``(V, 3)``.
        faces: triangle vertex indices converted to ``int64``; columns
            0, 1 and 2 are read as the three corners of each face.
        vertex_attrs: optional per-vertex payload, stored untouched.
        face_normal: ``(F, 3, 3)`` unit face normal repeated once per corner.
        vert_normal: ``(V, 3)`` unit vertex normals.
        res: resolution value supplied by the caller (default 64).
        success: ``True`` only when both ``verts`` and ``faces`` are non-empty.
        tsdf_v, tsdf_s, reg_loss: training-only slots, initialised to ``None``.
    """

    def __init__(self, verts, faces, vertex_attrs=None, res=64):
        self.verts = verts
        self.faces = faces.long()
        self.vertex_attrs = vertex_attrs
        self.face_normal = self.comput_face_normals()
        self.vert_normal = self.comput_v_normals()
        self.res = res
        # An extraction that produced no vertices or no faces is a failure.
        self.success = not (verts.shape[0] == 0 or faces.shape[0] == 0)

        # training only
        self.tsdf_v = None
        self.tsdf_s = None
        self.reg_loss = None

    def _corner_ids_and_raw_normals(self):
        """Return the per-corner index tensors and un-normalised face normals.

        The raw normal of a face is ``cross(v1 - v0, v2 - v0)``; its length
        is left as-is so callers decide whether to normalise.
        """
        corner_ids = [self.faces[..., corner].long() for corner in range(3)]
        p0, p1, p2 = (self.verts[ids, :] for ids in corner_ids)
        raw_normals = torch.cross(p1 - p0, p2 - p0, dim=-1)
        return corner_ids, raw_normals

    def comput_face_normals(self):
        """Return unit face normals, repeated per corner as ``(F, 3, 3)``."""
        _, raw_normals = self._corner_ids_and_raw_normals()
        unit_normals = torch.nn.functional.normalize(raw_normals, dim=1)
        return unit_normals[:, None, :].repeat(1, 3, 1)

    def comput_v_normals(self):
        """Return unit vertex normals of shape ``(V, 3)``.

        Each vertex accumulates the un-normalised normals of every face that
        references it (so faces with a longer cross product weigh more), and
        the sum is normalised at the end.
        """
        corner_ids, raw_normals = self._corner_ids_and_raw_normals()
        accumulated = torch.zeros_like(self.verts)
        for ids in corner_ids:
            # Broadcast the vertex index across x/y/z for scatter_add_.
            accumulated.scatter_add_(0, ids[..., None].repeat(1, 3), raw_normals)
        return torch.nn.functional.normalize(accumulated, dim=1)
def center_vertices(vertices):
    """Shift ``vertices`` so their axis-aligned bounding box is centred at 0.

    The minimum and maximum are taken along dimension 0, and the midpoint
    of the two is subtracted from every row. Returns a new tensor.
    """
    lower = vertices.min(dim=0).values
    upper = vertices.max(dim=0).values
    midpoint = 0.5 * (lower + upper)
    return vertices - midpoint
class SurfaceExtractor:
    """Base class for extracting one mesh per item of a batch of grids.

    Subclasses implement :meth:`run` for a single grid; calling the
    extractor instance iterates over the leading (batch) dimension.
    """

    def _compute_box_stat(
        self,
        bounds: Union[Tuple[float, ...], List[float], float],
        octree_resolution: int,
    ):
        """Return ``(grid_size, bbox_min, bbox_size)`` for the given bounds.

        Args:
            bounds: either a scalar ``b`` (meaning the cube ``[-b, b]^3``)
                or a 6-element sequence
                ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            octree_resolution: number of cells per axis.

        Returns:
            grid_size: list of three ints, each ``octree_resolution + 1``.
            bbox_min: NumPy array holding the lower corner.
            bbox_size: NumPy array holding ``bbox_max - bbox_min``.
        """
        # Accept any real scalar, not only ``float``: an int such as ``1`` or
        # a NumPy scalar used to fall through and crash on ``bounds[0:3]``.
        if isinstance(bounds, (int, float, np.integer, np.floating)):
            half_extent = float(bounds)
            bounds = [-half_extent] * 3 + [half_extent] * 3

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = [int(octree_resolution) + 1] * 3
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """Extract ``(verts, faces)`` from a single grid; subclasses override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # The exception must be raised, not returned: returning the class
        # made callers fail later with an unrelated unpacking error.
        raise NotImplementedError(
            f"{type(self).__name__} must implement run()"
        )

    def __call__(self, grid_logits, **kwargs):
        """Run extraction on every item along the first dimension.

        Args:
            grid_logits: batch of grids; ``grid_logits[i]`` is passed to
                :meth:`run`.
            **kwargs: forwarded to :meth:`run`. ``octree_resolution`` is
                also read here and stored as ``res`` on the result.

        Returns:
            A list with one entry per batch item: a ``MeshExtractResult``
            on success, or ``None`` if extraction raised.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                verts, faces = self.run(grid_logits[i], **kwargs)
                outputs.append(
                    MeshExtractResult(
                        verts=verts.float(),
                        faces=faces,
                        res=kwargs["octree_resolution"],
                    )
                )
            except Exception:
                # Deliberate best-effort: one failing item must not abort the
                # batch. The traceback is printed and None marks the failure.
                import traceback

                traceback.print_exc()
                outputs.append(None)
        return outputs
class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by scikit-image's marching cubes."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """Extract an iso-surface from one grid with marching cubes.

        Args:
            grid_logit: tensor holding the scalar field; it is moved to the
                CPU as float32 for scikit-image.
            mc_level: iso-value handed to ``measure.marching_cubes``.
            bounds: bounding box spec understood by ``_compute_box_stat``.
            octree_resolution: grid resolution used to rescale vertices.
            **kwargs: accepted and ignored.

        Returns:
            ``(verts, faces)``: float32 vertices and int64 faces, both on
            ``grid_logit``'s device. Each face's vertex order is reversed
            relative to what marching cubes returned.
        """
        target_device = grid_logit.device
        volume = grid_logit.float().cpu().numpy()

        # Only vertices and faces are used; normals and values are dropped.
        mc_verts, mc_faces, _normals, _values = measure.marching_cubes(
            volume, mc_level, method="lewiner"
        )

        grid_size, bbox_min, bbox_size = self._compute_box_stat(
            bounds, octree_resolution
        )
        # Map grid-index coordinates into the bounding box.
        # NOTE(review): this divides by grid_size (= octree_resolution + 1).
        # If the grid samples both box faces inclusively, the divisor should
        # presumably be octree_resolution - confirm against grid creation.
        scaled_verts = mc_verts / grid_size * bbox_size + bbox_min

        verts = torch.tensor(scaled_verts, device=target_device, dtype=torch.float32)
        faces = torch.tensor(
            np.ascontiguousarray(mc_faces), device=target_device, dtype=torch.long
        )
        # Reverse the corner order of every triangle (flips the winding).
        return verts, faces[:, [2, 1, 0]]
class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by ``diso``'s differentiable dual marching cubes."""

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """Extract a mesh from one grid with ``diso.DiffDMC``.

        Args:
            grid_logit: tensor holding the scalar field for one item.
            octree_resolution: grid resolution; also used to scale the
                negated logits into the field passed to the extractor.
            **kwargs: must contain ``bounds`` (a scalar half-extent or a
                6-element ``[xmin, ymin, zmin, xmax, ymax, zmax]`` sequence).

        Returns:
            ``(verts, faces)`` as returned by ``DiffDMC``, with ``verts``
            mapped into the bounding box.

        Raises:
            ImportError: if ``diso`` cannot be imported.
            KeyError: if ``bounds`` is not supplied.
        """
        device = grid_logit.device

        # The extractor is built lazily on first use and cached on the
        # instance, on the device of the first grid seen.
        if not hasattr(self, "dmc"):
            try:
                from diso import DiffDMC
            except Exception as err:
                # Still surfaces as ImportError for callers, but no longer
                # swallows KeyboardInterrupt/SystemExit the way a bare
                # ``except:`` did, and the real cause is chained.
                raise ImportError(
                    "Please install diso via `pip install diso`, or set mc_algo to 'mc'"
                ) from err
            self.dmc = DiffDMC(dtype=torch.float32).to(device)

        sdf = (-grid_logit / octree_resolution).to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)

        # Use the computed box instead of ``verts * bounds * 2 - bounds``:
        # identical for a scalar bound b (min = -b, size = 2b), but this
        # also works when ``bounds`` is a 6-element sequence.
        # NOTE(review): assumes normalize=True yields verts in [0, 1] - confirm.
        _, bbox_min, bbox_size = self._compute_box_stat(
            kwargs["bounds"], octree_resolution
        )
        box_origin = torch.as_tensor(bbox_min, dtype=verts.dtype, device=verts.device)
        box_extent = torch.as_tensor(bbox_size, dtype=verts.dtype, device=verts.device)
        verts = verts * box_extent + box_origin
        return verts, faces