import traceback
from typing import List, Tuple, Union

import numpy as np
import torch

try:
    from skimage import measure
except ImportError:  # scikit-image is only needed by MCSurfaceExtractor
    measure = None


class MeshExtractResult:
    """Container for an extracted triangle mesh plus derived normals.

    Args:
        verts: vertex positions, indexed as ``verts[idx, :]`` — presumably a
            (V, 3) float tensor (TODO confirm against callers).
        faces: triangle vertex indices; cast to ``long`` and read through
            columns 0, 1, 2.
        vertex_attrs: optional per-vertex attributes, stored as-is.
        res: resolution the mesh was extracted at, stored as-is.

    Attributes:
        face_normal: unit face normal repeated once per triangle corner,
            shape (F, 3, 3).
        vert_normal: unit vertex normal; the area-weighted sum of the
            unnormalized face normals of every incident face.
        success: True when the mesh has at least one vertex and one face.
    """

    def __init__(self, verts, faces, vertex_attrs=None, res=64):
        self.verts = verts
        self.faces = faces.long()
        self.vertex_attrs = vertex_attrs
        self.face_normal = self.comput_face_normals()
        self.vert_normal = self.comput_v_normals()
        self.res = res
        self.success = verts.shape[0] != 0 and faces.shape[0] != 0

        # training only
        self.tsdf_v = None
        self.tsdf_s = None
        self.reg_loss = None

    def _face_cross(self):
        """Return the three corner index tensors and the unnormalized face cross products."""
        corners = tuple(self.faces[..., k].long() for k in range(3))
        v0, v1, v2 = (self.verts[idx, :] for idx in corners)
        # (v1 - v0) x (v2 - v0): direction follows the face winding; the
        # magnitude is twice the triangle area.
        return corners, torch.cross(v1 - v0, v2 - v0, dim=-1)

    def comput_face_normals(self):
        """Return unit face normals repeated for each of the 3 corners, shape (F, 3, 3)."""
        _, cross = self._face_cross()
        face_normals = torch.nn.functional.normalize(cross, dim=1)
        return face_normals[:, None, :].repeat(1, 3, 1)

    def comput_v_normals(self):
        """Return unit vertex normals accumulated from incident faces.

        The unnormalized cross products are summed, so larger faces weigh
        more; the sum is then normalized per vertex.
        """
        corners, cross = self._face_cross()
        v_normals = torch.zeros_like(self.verts)
        for idx in corners:
            v_normals.scatter_add_(0, idx[..., None].repeat(1, 3), cross)
        return torch.nn.functional.normalize(v_normals, dim=1)


def center_vertices(vertices):
    """Translate the vertices so that bounding box is centered at zero."""
    vert_min = vertices.min(dim=0)[0]
    vert_max = vertices.max(dim=0)[0]
    vert_center = 0.5 * (vert_min + vert_max)
    return vertices - vert_center


class SurfaceExtractor:
    """Base class turning a batch of dense logit grids into meshes.

    Subclasses implement :meth:`run` for a single grid; calling the instance
    runs it over the leading batch dimension.
    """

    def _compute_box_stat(
        self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int
    ):
        """Return ``(grid_size, bbox_min, bbox_size)`` for the sampling box.

        Args:
            bounds: either a scalar ``b`` (int or float), meaning the cube
                ``[-b, b]^3``, or a 6-sequence ``(xmin, ymin, zmin, xmax,
                ymax, zmax)``.
            octree_resolution: number of cells per axis; the grid has one
                more sample than cells along each axis.

        Returns:
            grid_size: ``[res + 1] * 3`` as a list of ints.
            bbox_min, bbox_size: NumPy arrays of length 3.
        """
        if isinstance(bounds, (int, float)):
            b = float(bounds)
            bounds = [-b, -b, -b, b, b, b]
        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = [int(octree_resolution) + 1] * 3
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """Extract ``(verts, faces)`` from a single grid; must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement run()")

    def __call__(self, grid_logits, **kwargs):
        """Extract one mesh per item along the leading batch dimension.

        Extraction is best-effort per item: if anything raises for item ``i``,
        the traceback is printed and ``None`` is stored in that slot, so the
        result always has ``grid_logits.shape[0]`` entries. ``kwargs`` are
        forwarded to :meth:`run`; ``octree_resolution`` is expected (it is
        also stored as each result's ``res``).
        """
        outputs = []
        for grid_logit in grid_logits:
            try:
                verts, faces = self.run(grid_logit, **kwargs)
                outputs.append(
                    MeshExtractResult(
                        verts=verts.float(),
                        faces=faces,
                        res=kwargs["octree_resolution"],
                    )
                )
            except Exception:
                # Deliberate: one failed item must not abort the whole batch.
                traceback.print_exc()
                outputs.append(None)
        return outputs


class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extraction with scikit-image marching cubes (Lewiner variant)."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """Extract the ``mc_level`` iso-surface of ``grid_logit``.

        Returns:
            verts: float32 tensor on ``grid_logit.device``, mapped into the box
                described by ``bounds``.
            faces: long tensor on the same device; winding reversed from
                scikit-image's output.

        Raises:
            ImportError: if scikit-image is not installed.
        """
        if measure is None:
            raise ImportError(
                "Please install scikit-image via `pip install scikit-image` "
                "to use MCSurfaceExtractor"
            )
        verts, faces, _normals, _values = measure.marching_cubes(
            grid_logit.float().cpu().numpy(), mc_level, method="lewiner"
        )
        grid_size, bbox_min, bbox_size = self._compute_box_stat(
            bounds, octree_resolution
        )
        # NOTE(review): marching-cubes vertices are in voxel-index units
        # (0 .. grid_size - 1); dividing by grid_size rather than
        # grid_size - 1 places the last sample slightly inside bbox_max.
        # Kept as-is — confirm against how the query grid was sampled.
        verts = verts / grid_size * bbox_size + bbox_min
        verts = torch.tensor(verts, device=grid_logit.device, dtype=torch.float32)
        faces = torch.tensor(
            np.ascontiguousarray(faces), device=grid_logit.device, dtype=torch.long
        )
        # Reverse the corner order of each triangle (flips the winding).
        faces = faces[:, [2, 1, 0]]
        return verts, faces


class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extraction with differentiable dual marching cubes (``diso``)."""

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """Extract a mesh from ``grid_logit`` using ``diso.DiffDMC``.

        The DiffDMC module is created lazily on first use, on the device of
        the first grid it sees. ``kwargs["bounds"]`` (scalar or 6-sequence,
        see ``_compute_box_stat``) is required.

        Raises:
            ImportError: if ``diso`` cannot be imported.
        """
        device = grid_logit.device
        if not hasattr(self, "dmc"):
            try:
                from diso import DiffDMC
            except ImportError as err:
                raise ImportError(
                    "Please install diso via `pip install diso`, or set mc_algo to 'mc'"
                ) from err
            self.dmc = DiffDMC(dtype=torch.float32).to(device)
        # Logits are positive inside; negate (and scale) to get an SDF.
        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
        _, bbox_min, bbox_size = self._compute_box_stat(
            kwargs["bounds"], octree_resolution
        )
        # With normalize=True the vertices are presumably in [0, 1] per axis
        # (the previous `verts * 2b - b` mapping assumed this); map them into
        # the bounding box. Identical to that mapping for scalar bounds.
        scale = torch.as_tensor(bbox_size, dtype=verts.dtype, device=verts.device)
        offset = torch.as_tensor(bbox_min, dtype=verts.dtype, device=verts.device)
        verts = verts * scale + offset
        return verts, faces