import collections import itertools import math from typing import Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import nvdiffrast.torch as dr import torch import torch.nn as nn import torch.nn.functional as F import xatlas from diffusers import ConfigMixin, ModelMixin from transformers import PreTrainedModel, ViTConfig, ViTImageProcessor from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.pytorch_utils import ( find_pruneable_heads_and_indices, prune_linear_layer, ) def generate_planes(): """ Defines planes by the three vectors that form the "axes" of the plane. Should work with arbitrary number of planes and planes of arbitrary orientation. Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 """ return torch.tensor( [ [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]], [[0, 0, 1], [0, 1, 0], [1, 0, 0]], ], dtype=torch.float32, ) def project_onto_planes(planes, coordinates): """ Does a projection of a 3D point onto a batch of 2D planes, returning 2D plane coordinates. Takes plane axes of shape n_planes, 3, 3 # Takes coordinates of shape N, M, 3 # returns projections of shape N*n_planes, M, 2 """ N, M, C = coordinates.shape n_planes, _, _ = planes.shape coordinates = ( coordinates.unsqueeze(1) .expand(-1, n_planes, -1, -1) .reshape(N * n_planes, M, 3) ) inv_planes = ( torch.linalg.inv(planes) .unsqueeze(0) .expand(N, -1, -1, -1) .reshape(N * n_planes, 3, 3) ) projections = torch.bmm(coordinates, inv_planes) return projections[..., :2] def sample_from_planes( plane_axes, plane_features, coordinates, mode="bilinear", padding_mode="zeros", box_warp=None, ): assert padding_mode == "zeros" N, n_planes, C, H, W = plane_features.shape _, M, _ = coordinates.shape plane_features = plane_features.view(N * n_planes, C, H, W) dtype = plane_features.dtype coordinates = (2 / box_warp) * coordinates # add specific box bounds projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) output_features = ( torch.nn.functional.grid_sample( plane_features, projected_coordinates.to(dtype), mode=mode, padding_mode=padding_mode, align_corners=False, ) .permute(0, 3, 2, 1) .reshape(N, n_planes, M, C) ) return output_features class OSGDecoder(nn.Module): """ Triplane decoder that gives RGB and sigma values from sampled features. Using ReLU here instead of Softplus in the original implementation. Reference: EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 """ def __init__( self, n_features: int, hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU, ): super().__init__() self.net_sdf = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain( *[ [ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2) ] ), nn.Linear(hidden_dim, 1), ) self.net_rgb = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain( *[ [ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2) ] ), nn.Linear(hidden_dim, 3), ) self.net_deformation = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain( *[ [ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2) ] ), nn.Linear(hidden_dim, 3), ) self.net_weight = nn.Sequential( nn.Linear(8 * 3 * n_features, hidden_dim), activation(), *itertools.chain( *[ [ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2) ] ), nn.Linear(hidden_dim, 21), ) # init all bias to zero for m in self.modules(): if isinstance(m, nn.Linear): nn.init.zeros_(m.bias) def get_geometry_prediction(self, sampled_features, flexicubes_indices): _N, n_planes, _M, _C = sampled_features.shape sampled_features = sampled_features.permute(0, 2, 1, 3).reshape( _N, _M, n_planes * _C ) sdf = self.net_sdf(sampled_features) deformation = self.net_deformation(sampled_features) grid_features = torch.index_select( input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1 ) grid_features = grid_features.reshape( sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1], ) weight = self.net_weight(grid_features) * 0.1 return sdf, deformation, weight def get_texture_prediction(self, sampled_features): _N, n_planes, _M, _C = sampled_features.shape sampled_features = sampled_features.permute(0, 2, 1, 3).reshape( _N, _M, n_planes * _C ) rgb = self.net_rgb(sampled_features) rgb = ( torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001 ) # Uses sigmoid clamping from MipNeRF return rgb class TriplaneSynthesizer(nn.Module): """ Synthesizer that renders a triplane volume with planes and a camera. Reference: EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 """ DEFAULT_RENDERING_KWARGS = { "ray_start": "auto", "ray_end": "auto", "box_warp": 2.0, "white_back": True, "disparity_space_sampling": False, "clamp_mode": "softplus", "sampler_bbox_min": -1.0, "sampler_bbox_max": 1.0, } def __init__(self, triplane_dim: int, samples_per_ray: int): super().__init__() # attributes self.triplane_dim = triplane_dim self.rendering_kwargs = { **self.DEFAULT_RENDERING_KWARGS, "depth_resolution": samples_per_ray // 2, "depth_resolution_importance": samples_per_ray // 2, } # modules self.plane_axes = generate_planes() self.decoder = OSGDecoder(n_features=triplane_dim) def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): plane_axes = self.plane_axes.to(planes.device) sampled_features = sample_from_planes( plane_axes, planes, sample_coordinates, padding_mode="zeros", box_warp=self.rendering_kwargs["box_warp"], ) sdf, deformation, weight = self.decoder.get_geometry_prediction( sampled_features, flexicubes_indices ) return sdf, deformation, weight def get_texture_prediction(self, planes, sample_coordinates): plane_axes = self.plane_axes.to(planes.device) sampled_features = sample_from_planes( plane_axes, planes, sample_coordinates, padding_mode="zeros", box_warp=self.rendering_kwargs["box_warp"], ) rgb = self.decoder.get_texture_prediction(sampled_features) return rgb dmc_table = [ [ [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], ], [ [0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], ], [ [1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], [ [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], ], ] num_vd_table = [ 0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ] check_table = [ [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 194], [1, -1, 0, 0, 193], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 164], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, -1, 0, 161], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 152], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 145], [1, 0, 0, 1, 144], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, -1, 137], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 133], [1, 0, 1, 0, 132], [1, 1, 0, 0, 131], [1, 1, 0, 0, 130], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 100], [0, 0, 0, 0, 0], [1, 0, 0, 1, 98], [0, 0, 0, 0, 0], [1, 0, 0, 1, 96], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 88], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, -1, 0, 82], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 74], [0, 0, 0, 0, 0], [1, 0, 1, 0, 72], [0, 0, 0, 0, 0], [1, 0, 0, -1, 70], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 67], [0, 0, 0, 0, 0], [1, -1, 0, 0, 65], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 56], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 52], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 44], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 40], [0, 0, 0, 0, 0], [1, 0, 0, -1, 38], [1, 0, -1, 0, 37], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, -1, 0, 33], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 28], [0, 0, 0, 0, 0], [1, 0, -1, 0, 26], [1, 0, 0, -1, 25], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 20], [0, 0, 0, 0, 0], [1, 0, -1, 0, 18], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, -1, 9], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, -1, 6], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], ] tet_table = [ [-1, -1, -1, -1, -1, -1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [4, 4, 4, 4, 4, 4], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, -1], [1, 1, 1, 1, 1, 1], [4, 4, 4, 4, 4, 4], [0, 4, 0, 4, 4, -1], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [5, 5, 5, 5, 5, 5], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, -1, 0, 2], [1, 1, 1, 1, 1, 1], [2, -1, 2, 4, 4, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 4, 4, 2], [1, 1, 1, 1, 1, 1], [2, 4, 2, 4, 4, 2], [0, 4, 0, 4, 4, 0], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [2, 5, 2, 5, 5, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [0, 1, 1, -1, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [4, 1, 1, 4, 4, 1], [0, 1, 1, 0, 0, 1], [4, 0, 0, 4, 4, 0], [2, 2, 2, 2, 2, 2], [-1, 1, 1, 4, 4, 1], [0, 1, 1, 4, 4, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [5, 1, 1, 5, 5, 1], [0, 1, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [8, 8, 8, 8, 8, 8], [1, 1, 1, 4, 4, 1], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 4, 4, 1], [0, 4, 0, 4, 4, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 5, 5, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [6, -1, 0, 6, 0, 6], [6, 0, 0, 6, 0, 6], [6, 1, 1, 6, 1, 6], [4, 4, 4, 4, 4, 4], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, 4], [1, 1, 1, 1, 1, 1], [6, 4, -1, 6, 4, 6], [6, 4, 0, 6, 4, 6], [6, 0, 0, 6, 0, 6], [6, 1, 1, 6, 1, 6], [5, 5, 5, 5, 5, 5], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 2, 0, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [2, 4, 2, 2, 4, 2], [0, 4, 0, 4, 4, 0], [2, 0, 2, 2, 0, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [6, 1, 1, 6, -1, 6], [6, 1, 1, 6, 0, 6], [6, 0, 0, 6, 0, 6], [6, 2, 2, 6, 2, 6], [4, 1, 1, 4, 4, 1], [0, 1, 1, 0, 0, 1], [4, 0, 0, 4, 4, 4], [2, 2, 2, 2, 2, 2], [6, 1, 1, 6, 4, 6], [6, 1, 1, 6, 4, 6], [6, 0, 0, 6, 0, 6], [6, 2, 2, 6, 2, 6], [5, 1, 1, 5, 5, 1], [0, 1, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [6, 6, 6, 6, 6, 6], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 4, 1], [0, 4, 0, 4, 4, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 5, 0, 5, 0, 5], [5, 5, 5, 5, 5, 5], [5, 5, 5, 5, 5, 5], [0, 5, 0, 5, 0, 5], [-1, 5, 0, 5, 0, 5], [1, 5, 1, 5, 1, 5], [4, 5, -1, 5, 4, 5], [0, 5, 0, 5, 0, 5], [4, 5, 0, 5, 4, 5], [1, 5, 1, 5, 1, 5], [4, 4, 4, 4, 4, 4], [0, 4, 0, 4, 4, 4], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [6, 6, 6, 6, 6, 6], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 5, 2, 5, -1, 5], [0, 5, 0, 5, 0, 5], [2, 5, 2, 5, 0, 5], [1, 5, 1, 5, 1, 5], [2, 5, 2, 5, 4, 5], [0, 5, 0, 5, 0, 5], [2, 5, 2, 5, 4, 5], [1, 5, 1, 5, 1, 5], [2, 4, 2, 4, 4, 2], [0, 4, 0, 4, 4, 4], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [2, 6, 2, 6, 6, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [0, 1, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [4, 1, 1, 1, 4, 1], [0, 1, 1, 1, 0, 1], [4, 0, 0, 4, 4, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 4, 1], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [6, 0, 0, 6, 0, 6], [0, 0, 0, 0, 0, 0], [6, 6, 6, 6, 6, 6], [5, 5, 5, 5, 5, 5], [5, 5, 0, 5, 0, 5], [5, 5, 0, 5, 0, 5], [5, 5, 1, 5, 1, 5], [4, 4, 4, 4, 4, 4], [0, 0, 0, 0, 0, 0], [4, 4, 0, 4, 4, 4], [1, 1, 1, 1, 1, 1], [4, 4, 4, 4, 4, 4], [4, 4, 0, 4, 4, 4], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [8, 8, 8, 8, 8, 8], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 0, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 1, 1, 4, 4, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [2, 4, 2, 4, 4, 2], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [12, 12, 12, 12, 12, 12], ] class FlexiCubes: def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): self.device = device self.dmc_table = torch.tensor( dmc_table, dtype=torch.long, device=device, requires_grad=False ) self.num_vd_table = torch.tensor( num_vd_table, dtype=torch.long, device=device, requires_grad=False ) self.check_table = torch.tensor( check_table, dtype=torch.long, device=device, requires_grad=False ) self.tet_table = torch.tensor( tet_table, dtype=torch.long, device=device, requires_grad=False ) self.quad_split_1 = torch.tensor( [0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False ) self.quad_split_2 = torch.tensor( [0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False ) self.quad_split_train = torch.tensor( [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False, ) self.cube_corners = torch.tensor( [ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], ], dtype=torch.float, device=device, ) self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) self.cube_edges = torch.tensor( [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False, ) self.edge_dir_table = torch.tensor( [0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=torch.long, device=device ) self.dir_faces_table = torch.tensor( [ [[5, 4], [3, 2], [4, 5], [2, 3]], [[5, 4], [1, 0], [4, 5], [0, 1]], [[3, 2], [1, 0], [2, 3], [0, 1]], ], dtype=torch.long, device=device, ) self.adj_pairs = torch.tensor( [0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device ) self.qef_reg_scale = qef_reg_scale self.weight_scale = weight_scale def construct_voxel_grid(self, res): """ Generates a voxel grid based on the specified resolution. Args: res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it is used for all three dimensions. If a list or tuple of 3 integers is provided, they define the resolution for the x, y, and z dimensions respectively. Returns: (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the cube corners (index into vertices) of the constructed voxel grid. The vertices are centered at the origin, with the length of each dimension in the grid being one. """ base_cube_f = torch.arange(8).to(self.device) if isinstance(res, int): res = (res, res, res) voxel_grid_template = torch.ones(res, device=self.device) res = torch.tensor([res], dtype=torch.float, device=self.device) coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape( -1, 3 ) cubes = ( base_cube_f.unsqueeze(0) + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8 ).reshape(-1) verts_rounded = torch.round(verts * 10**5) / (10**5) verts_unique, inverse_indices = torch.unique( verts_rounded, dim=0, return_inverse=True ) cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) return verts_unique - 0.5, cubes def __call__( self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, gamma_f=None, training=False, output_tetmesh=False, grad_func=None, ): r""" Main function for mesh extraction from scalar field using FlexiCubes. This function converts discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, to triangle or tetrahedral meshes using a differentiable operation as described in `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances mesh quality and geometric fidelity by adjusting the surface representation based on gradient optimization. The output surface is differentiable with respect to the input vertex positions, scalar field values, and weight parameters. If you intend to extract a surface mesh from a fixed Signed Distance Field without the optimization of parameters, it is suggested to provide the "grad_func" which should return the surface gradient at any given 3D position. When grad_func is provided, the process to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. Please note, this approach is non-differentiable. For more details and example usage in optimization, refer to the `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. Args: x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values denote that the corresponding vertex resides inside the isosurface. This affects the directions of the extracted triangle faces and volume to be tetrahedralized. cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it is used for all three dimensions. If a list or tuple of 3 integers is provided, they specify the resolution for the x, y, and z dimensions respectively. beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual vertices positioning. Defaults to uniform value for all edges. alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual vertices positioning. Defaults to uniform value for all vertices. gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of quadrilaterals into triangles. Defaults to uniform value for all cubes. training (bool, optional): If set to True, applies differentiable quad splitting for training. Defaults to False. output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, outputs a triangular mesh. Defaults to False. grad_func (callable, optional): A function to compute the surface gradient at specified 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. Returns: (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: - Vertices for the extracted triangular/tetrahedral mesh. - Faces for the extracted triangular/tetrahedral mesh. - Regularizer L_dev, computed per dual vertex. .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: https://research.nvidia.com/labs/toronto-ai/flexicubes/ .. _Manifold Dual Contouring: https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf """ surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) if surf_cubes.sum() == 0: return ( torch.zeros((0, 3), device=self.device), ( torch.zeros((0, 4), dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros((0, 3), dtype=torch.long, device=self.device) ), torch.zeros((0), device=self.device), ) beta_fx12, alpha_fx8, gamma_f = self._normalize_weights( beta_fx12, alpha_fx8, gamma_f, surf_cubes ) case_ids = self._get_case_id(occ_fx8, surf_cubes, res) surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( s_n, cube_fx8, surf_cubes ) vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func, ) vertices, faces, s_edges, edge_indices = self._triangulate( s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func, ) if not output_tetmesh: return vertices, faces, L_dev else: vertices, tets = self._tetrahedralize( x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, surf_cubes, training, ) return vertices, tets, L_dev def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): """ Regularizer L_dev as in Equation 8 """ dist = torch.norm( ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1 ) mean_l2 = torch.zeros_like(vd[:, 0]) mean_l2 = (mean_l2).index_add_( 0, edge_group_to_vd, dist ) / vd_num_edges.squeeze(1).float() mad = ( dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0) ).abs() return mad def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): """ Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. """ n_cubes = surf_cubes.shape[0] if beta_fx12 is not None: beta_fx12 = torch.tanh(beta_fx12) * self.weight_scale + 1 else: beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) if alpha_fx8 is not None: alpha_fx8 = torch.tanh(alpha_fx8) * self.weight_scale + 1 else: alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) if gamma_f is not None: gamma_f = ( torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale) / 2 ) else: gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] @torch.no_grad() def _get_case_id(self, occ_fx8, surf_cubes, res): """ Obtains the ID of topology cases based on cell corner occupancy. This function resolves the ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the supplementary material. It should be noted that this function assumes a regular grid. """ case_ids = ( occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0) ).sum(-1) problem_config = self.check_table.to(self.device)[case_ids] to_check = problem_config[..., 0] == 1 problem_config = problem_config[to_check] if not isinstance(res, (list, tuple)): res = [res, res, res] # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). # This allows efficient checking on adjacent cubes. problem_config_full = torch.zeros( list(res) + [5], device=self.device, dtype=torch.long ) vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 vol_idx_problem = vol_idx[surf_cubes][to_check] problem_config_full[ vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2] ] = problem_config vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] within_range = ( (vol_idx_problem_adj[..., 0] >= 0) & (vol_idx_problem_adj[..., 0] < res[0]) & (vol_idx_problem_adj[..., 1] >= 0) & (vol_idx_problem_adj[..., 1] < res[1]) & (vol_idx_problem_adj[..., 2] >= 0) & (vol_idx_problem_adj[..., 2] < res[2]) ) vol_idx_problem = vol_idx_problem[within_range] vol_idx_problem_adj = vol_idx_problem_adj[within_range] problem_config = problem_config[within_range] problem_config_adj = problem_config_full[ vol_idx_problem_adj[..., 0], vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2], ] # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. to_invert = problem_config_adj[..., 0] == 1 idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][ within_range ][to_invert] case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) return case_ids @torch.no_grad() def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): """ Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge and marks the cube edges with this index. """ occ_n = s_n < 0 all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) unique_edges, _idx_map, counts = torch.unique( all_edges, dim=0, return_inverse=True, return_counts=True ) unique_edges = unique_edges.long() mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 surf_edges_mask = mask_edges[_idx_map] counts = counts[_idx_map] mapping = ( torch.ones( (unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device ) * -1 ) mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. idx_map = mapping[_idx_map] surf_edges = unique_edges[mask_edges] return surf_edges, idx_map, counts, surf_edges_mask @torch.no_grad() def _identify_surf_cubes(self, s_n, cube_fx8): """ Identifies grid cubes that intersect with the underlying surface by checking if the signs at all corners are not identical. """ occ_n = s_n < 0 occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) _occ_sum = torch.sum(occ_fx8, -1) surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) return surf_cubes, occ_fx8 def _linear_interp(self, edges_weight, edges_x): """ Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. """ edge_dim = edges_weight.dim() - 2 assert edges_weight.shape[edge_dim] == 2 edges_weight = torch.cat( [ torch.index_select( input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim, ), -torch.index_select( input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim, ), ], edge_dim, ) denominator = edges_weight.sum(edge_dim) ue = (edges_x * edges_weight).sum(edge_dim) / denominator return ue def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) c_bx3 = c_bx3.reshape(-1, 3) A = norm_bxnx3 B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) A_reg = ( (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale) .unsqueeze(0) .repeat(p_bxnx3.shape[0], 1, 1) ) B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) A = torch.cat([A, A_reg], 1) B = torch.cat([B, B_reg], 1) dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) return dual_verts def _compute_vd( self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func, ): """ Computes the location of dual vertices as described in Section 4.2 """ alpha_nx12x2 = torch.index_select( input=alpha_fx8, index=self.cube_edges, dim=1 ).reshape(-1, 12, 2) surf_edges_x = torch.index_select( input=x_nx3, index=surf_edges.reshape(-1), dim=0 ).reshape(-1, 2, 3) surf_edges_s = torch.index_select( input=s_n, index=surf_edges.reshape(-1), dim=0 ).reshape(-1, 2, 1) zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) idx_map = idx_map.reshape(-1, 12) num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = ( [], [], [], [], [], ) total_num_vd = 0 vd_idx_map = torch.zeros( (case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False, ) if grad_func is not None: normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) vd = [] for num in torch.unique(num_vd): cur_cubes = ( num_vd == num ) # consider cubes with the same numbers of vd emitted (for batching) curr_num_vd = cur_cubes.sum() * num curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape( -1, num * 7 ) curr_edge_group_to_vd = ( torch.arange(curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd ) total_num_vd += curr_num_vd curr_edge_group_to_cube = ( torch.arange(idx_map.shape[0], device=self.device)[cur_cubes] .unsqueeze(-1) .repeat(1, num * 7) .reshape_as(curr_edge_group) ) curr_mask = curr_edge_group != -1 edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) edge_group_to_vd.append( torch.masked_select( curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask ) ) edge_group_to_cube.append( torch.masked_select(curr_edge_group_to_cube, curr_mask) ) vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) vd_gamma.append( torch.masked_select(gamma_f, cur_cubes) .unsqueeze(-1) .repeat(1, num) .reshape(-1) ) if grad_func is not None: with torch.no_grad(): cube_e_verts_idx = idx_map[cur_cubes] curr_edge_group[~curr_mask] = 0 verts_group_idx = torch.gather( input=cube_e_verts_idx, dim=1, index=curr_edge_group ) verts_group_idx[verts_group_idx == -1] = 0 verts_group_pos = torch.index_select( input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0 ).reshape(-1, num.item(), 7, 3) v0 = ( x_nx3[surf_cubes_fx8[cur_cubes][:, 0]] .reshape(-1, 1, 1, 3) .repeat(1, num.item(), 1, 1) ) curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) verts_centroid = (verts_group_pos * curr_mask).sum(2) / ( curr_mask.sum(2) ) normals_bx7x3 = torch.index_select( input=normals, index=verts_group_idx.reshape(-1), dim=0 ).reshape(-1, num.item(), 7, 3) curr_mask = curr_mask.squeeze(2) vd.append( self._solve_vd_QEF( (verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, verts_centroid - v0.squeeze(2), ) + v0.reshape(-1, 3) ) edge_group = torch.cat(edge_group) edge_group_to_vd = torch.cat(edge_group_to_vd) edge_group_to_cube = torch.cat(edge_group_to_cube) vd_num_edges = torch.cat(vd_num_edges) vd_gamma = torch.cat(vd_gamma) if grad_func is not None: vd = torch.cat(vd) L_dev = torch.zeros([1], device=self.device) else: vd = torch.zeros((total_num_vd, 3), device=self.device) beta_sum = torch.zeros((total_num_vd, 1), device=self.device) idx_group = torch.gather( input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group, ) x_group = torch.index_select( input=surf_edges_x, index=idx_group.reshape(-1), dim=0 ).reshape(-1, 2, 3) s_group = torch.index_select( input=surf_edges_s, index=idx_group.reshape(-1), dim=0 ).reshape(-1, 2, 1) zero_crossing_group = torch.index_select( input=zero_crossing, index=idx_group.reshape(-1), dim=0 ).reshape(-1, 3) alpha_group = torch.index_select( input=alpha_nx12x2.reshape(-1, 2), dim=0, index=edge_group_to_cube * 12 + edge_group, ).reshape(-1, 2, 1) ue_group = self._linear_interp(s_group * alpha_group, x_group) beta_group = torch.gather( input=beta_fx12.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group, ).reshape(-1, 1) beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) vd = ( vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum ) L_dev = self._compute_reg_loss( vd, zero_crossing_group, edge_group_to_vd, vd_num_edges ) v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd vd_idx_map = (vd_idx_map.reshape(-1)).scatter( dim=0, index=edge_group_to_cube * 12 + edge_group, src=v_idx[edge_group_to_vd], ) return vd, L_dev, vd_gamma, vd_idx_map def _triangulate( self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func, ): """ Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into triangles based on the gamma parameter, as described in Section 4.3. """ with torch.no_grad(): group_mask = ( edge_counts == 4 ) & surf_edges_mask # surface edges shared by 4 cubes. group = idx_map.reshape(-1)[group_mask] vd_idx = vd_idx_map[group_mask] edge_indices, indices = torch.sort(group, stable=True) quad_vd_idx = vd_idx[indices].reshape(-1, 4) # Ensure all face directions point towards the positive SDF to maintain consistent winding. s_edges = s_n[ surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1) ].reshape(-1, 2) flip_mask = s_edges[:, 0] > 0 quad_vd_idx = torch.cat( ( quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]], ) ) if grad_func is not None: # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. with torch.no_grad(): vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) quad_gamma = torch.index_select( input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0 ).reshape(-1, 4, 3) gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) else: quad_gamma = torch.index_select( input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0 ).reshape(-1, 4) gamma_02 = torch.index_select( input=quad_gamma, index=torch.tensor(0, device=self.device), dim=1 ) * torch.index_select( input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1 ) gamma_13 = torch.index_select( input=quad_gamma, index=torch.tensor(1, device=self.device), dim=1 ) * torch.index_select( input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1 ) if not training: mask = (gamma_02 > gamma_13).squeeze(1) faces = torch.zeros( (quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device ) faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] faces = faces.reshape(-1, 3) else: vd_quad = torch.index_select( input=vd, index=quad_vd_idx.reshape(-1), dim=0 ).reshape(-1, 4, 3) vd_02 = ( torch.index_select( input=vd_quad, index=torch.tensor(0, device=self.device), dim=1 ) + torch.index_select( input=vd_quad, index=torch.tensor(2, device=self.device), dim=1 ) ) / 2 vd_13 = ( torch.index_select( input=vd_quad, index=torch.tensor(1, device=self.device), dim=1 ) + torch.index_select( input=vd_quad, index=torch.tensor(3, device=self.device), dim=1 ) ) / 2 weight_sum = (gamma_02 + gamma_13) + 1e-8 vd_center = ( (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) ).squeeze(1) vd_center_idx = ( torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] ) vd = torch.cat([vd, vd_center]) faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) faces = torch.cat( [faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1 ).reshape(-1, 3) return vd, faces, s_edges, edge_indices def _tetrahedralize( self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, surf_cubes, training, ): """ Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. """ occ_n = s_n < 0 occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) occ_sum = torch.sum(occ_fx8, -1) inside_verts = x_nx3[occ_n] mapping_inside_verts = ( torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 ) mapping_inside_verts[occ_n] = ( torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] ) """ For each grid edge connecting two grid vertices with different signs, we first form a four-sided pyramid by connecting one of the grid vertices with four mesh vertices that correspond to the grid edge and then subdivide the pyramid into two tetrahedra """ inside_verts_idx = mapping_inside_verts[ surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[s_edges < 0] ] if not training: inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) else: inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) """ For each grid edge connecting two grid vertices with the same sign, the tetrahedron is formed by the two grid vertices and two vertices in consecutive adjacent cells """ inside_cubes = occ_sum == 8 inside_cubes_center = ( x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) ) inside_cubes_center_idx = ( torch.arange(inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] ) surface_n_inside_cubes = surf_cubes | inside_cubes edge_center_vertex_idx = ( torch.ones( ((surface_n_inside_cubes).sum(), 13), dtype=torch.long, device=x_nx3.device, ) * -1 ) surf_cubes = surf_cubes[surface_n_inside_cubes] inside_cubes = inside_cubes[surface_n_inside_cubes] edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) unique_edges, _idx_map, counts = torch.unique( all_edges, dim=0, return_inverse=True, return_counts=True ) unique_edges = unique_edges.long() mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 mask = mask_edges[_idx_map] counts = counts[_idx_map] mapping = ( torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 ) mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) idx_map = mapping[_idx_map] group_mask = (counts == 4) & mask group = idx_map.reshape(-1)[group_mask] edge_indices, indices = torch.sort(group) cube_idx = ( torch.arange( (_idx_map.shape[0] // 12), dtype=torch.long, device=self.device ) .unsqueeze(1) .expand(-1, 12) .reshape(-1)[group_mask] ) edge_idx = ( torch.arange((12), dtype=torch.long, device=self.device) .unsqueeze(0) .expand(_idx_map.shape[0] // 12, -1) .reshape(-1)[group_mask] ) # Identify the face shared by the adjacent cells. cube_idx_4 = cube_idx[indices].reshape(-1, 4) edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) # Identify an edge of the face with different signs and # select the mesh vertex corresponding to the identified edge. case_ids_expand = ( torch.ones( (surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device ) * 255 ) case_ids_expand[surf_cubes] = case_ids cases = case_ids_expand[cube_idx_4x2] quad_edge = edge_center_vertex_idx[ cube_idx_4x2, self.tet_table[cases, shared_faces_4x2] ].reshape(-1, 2) mask = (quad_edge == -1).sum(-1) == 0 inside_edge = mapping_inside_verts[ unique_edges[mask_edges][edge_indices].reshape(-1) ].reshape(-1, 2) tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] tets = torch.cat([tets_surface, tets_inside]) vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) return vertices, tets def get_center_boundary_index(grid_res, device): v = torch.zeros( (grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device ) v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True center_indices = torch.nonzero(v.reshape(-1)) v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False v[:2, ...] = True v[-2:, ...] = True v[:, :2, ...] = True v[:, -2:, ...] = True v[:, :, :2] = True v[:, :, -2:] = True boundary_indices = torch.nonzero(v.reshape(-1)) return center_indices, boundary_indices class Geometry: def __init__(self): pass def forward(self): pass class FlexiCubesGeometry(Geometry): def __init__( self, grid_res=64, scale=2.0, device="cuda", renderer=None, render_type="neural_render", args=None, ): super(FlexiCubesGeometry, self).__init__() self.grid_res = grid_res self.device = device self.args = args self.fc = FlexiCubes(device, weight_scale=0.5) self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) if isinstance(scale, list): self.verts[:, 0] = self.verts[:, 0] * scale[0] self.verts[:, 1] = self.verts[:, 1] * scale[1] self.verts[:, 2] = self.verts[:, 2] * scale[1] else: self.verts = self.verts * scale all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) self.all_edges = torch.unique(all_edges, dim=0) # Parameters used for fix boundary sdf self.center_indices, self.boundary_indices = get_center_boundary_index( self.grid_res, device ) self.renderer = renderer self.render_type = render_type def getAABB(self): return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values def get_mesh( self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False, ): if indices is None: indices = self.indices verts, faces, v_reg_loss = self.fc( v_deformed_nx3, sdf_n, indices, self.grid_res, beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], gamma_f=weight_n[:, 20], training=is_training, ) return verts, faces, v_reg_loss def render_mesh( self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False, ): return_value = dict() if self.render_type == "neural_render": tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = ( self.renderer.render_mesh( mesh_v_nx3.unsqueeze(dim=0), mesh_f_fx3.int(), camera_mv_bx4x4, mesh_v_nx3.unsqueeze(dim=0), resolution=resolution, device=self.device, hierarchical_mask=hierarchical_mask, ) ) return_value["tex_pos"] = tex_pos return_value["mask"] = mask return_value["hard_mask"] = hard_mask return_value["rast"] = rast return_value["v_pos_clip"] = v_pos_clip return_value["mask_pyramid"] = mask_pyramid return_value["depth"] = depth return_value["normal"] = normal else: raise NotImplementedError return return_value def render( self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256, ): # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 v_list = [] f_list = [] n_batch = v_deformed_bxnx3.shape[0] all_render_output = [] for i_batch in range(n_batch): verts_nx3, faces_fx3 = self.get_mesh( v_deformed_bxnx3[i_batch], sdf_bxn[i_batch] ) v_list.append(verts_nx3) f_list.append(faces_fx3) render_output = self.render_mesh( verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution ) all_render_output.append(render_output) # Concatenate all render output return_keys = all_render_output[0].keys() return_value = dict() for k in return_keys: value = [v[k] for v in all_render_output] return_value[k] = value # We can do concatenation outside of the render return return_value def interpolate(attr, rast, attr_idx, rast_db=None): return dr.interpolate( attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else "all", ) def xfm_points(points, matrix, use_python=True): """Transform points. Args: points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] use_python: Use PyTorch's torch.matmul (for validation) Returns: Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. """ out = torch.matmul( torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=1.0), torch.transpose(matrix, 1, 2), ) if torch.is_anomaly_enabled(): assert torch.all( torch.isfinite(out) ), "Output of xfm_points contains inf or NaN" return out def dot(x, y): return torch.sum(x * y, -1, keepdim=True) def compute_vertex_normal(v_pos, t_pos_idx): i0 = t_pos_idx[:, 0] i1 = t_pos_idx[:, 1] i2 = t_pos_idx[:, 2] v0 = v_pos[i0, :] v1 = v_pos[i1, :] v2 = v_pos[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0) # Splat face normals to vertices v_nrm = torch.zeros_like(v_pos) v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) # Normalize, replace zero (degenerated) normals with some default value v_nrm = torch.where( dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) ) v_nrm = F.normalize(v_nrm, dim=1) assert torch.all(torch.isfinite(v_nrm)) return v_nrm class Renderer: def __init__(self): pass def forward(self): pass class NeuralRender(Renderer): def __init__(self, device="cuda", camera_model=None): super(NeuralRender, self).__init__() self.device = device self.ctx = dr.RasterizeCudaContext(device=device) self.projection_mtx = None self.camera = camera_model def render_mesh( self, mesh_v_pos_bxnx3, mesh_t_pos_idx_fx3, camera_mv_bx4x4, mesh_v_feat_bxnxd, resolution=256, spp=1, device="cuda", hierarchical_mask=False, ): assert not hierarchical_mask mtx_in = ( torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 ) v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates v_pos_clip = self.camera.project(v_pos) # Projection in the camera v_nrm = compute_vertex_normal( mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long() ) # vertex normals in world coordinates # Render the image, # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render num_layers = 1 mask_pyramid = None assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes mesh_v_feat_bxnxd = torch.cat( [mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1 ) # Concatenate the pos with dr.DepthPeeler( self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp], ) as peeler: for _ in range(num_layers): rast, db = peeler.rasterize_next_layer() gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) hard_mask = torch.clamp(rast[..., -1:], 0, 1) antialias_mask = dr.antialias( hard_mask.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3 ) depth = gb_feat[..., -2:-1] ori_mesh_feature = gb_feat[..., :-4] normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) normal = dr.antialias( normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3 ) normal = F.normalize(normal, dim=-1) normal = torch.lerp( torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float() ) # black background return ( ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, ) def projection(x=0.1, n=1.0, f=50.0, near_plane=None): if near_plane is None: near_plane = n return np.array( [ [n / x, 0, 0, 0], [0, n / -x, 0, 0], [ 0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane), ], [0, 0, -1, 0], ] ).astype(np.float32) class Camera(nn.Module): def __init__(self): super(Camera, self).__init__() pass class PerspectiveCamera(Camera): def __init__(self, fovy=49.0, device="cuda"): super(PerspectiveCamera, self).__init__() self.device = device focal = np.tan(fovy / 180.0 * np.pi * 0.5) self.proj_mtx = ( torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)) .to(self.device) .unsqueeze(dim=0) ) def project(self, points_bxnx4): out = torch.matmul(points_bxnx4, torch.transpose(self.proj_mtx, 1, 2)) return out class ViTEmbeddings(nn.Module): def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.mask_token = ( nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None ) self.patch_embeddings = ViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter( torch.randn(1, num_patches + 1, config.hidden_size) ) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.config = config def interpolate_pos_encoding( self, embeddings: torch.Tensor, height: int, width: int ) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. Source: https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 if num_patches == num_positions and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, 0] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] h0 = height // self.config.patch_size w0 = width // self.config.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 h0, w0 = h0 + 0.1, w0 + 0.1 patch_pos_embed = patch_pos_embed.reshape( 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), mode="bicubic", align_corners=False, ) assert ( int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def forward( self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False, ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape embeddings = self.patch_embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding ) if bool_masked_pos is not None: seq_length = embeddings.shape[1] mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) # replace the masked visual tokens by mask_tokens mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding( embeddings, height, width ) else: embeddings = embeddings + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings class ViTPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. """ def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = ( image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) ) patch_size = ( patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) ) num_patches = (image_size[1] // patch_size[1]) * ( image_size[0] // patch_size[0] ) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d( num_channels, hidden_size, kernel_size=patch_size, stride=patch_size ) def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) if not interpolate_pos_encoding: if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size[0]}*{self.image_size[1]})." ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings class ViTSelfAttention(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr( config, "embedding_size" ): raise ValueError( f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear( config.hidden_size, self.all_head_size, bias=config.qkv_bias ) self.key = nn.Linear( config.hidden_size, self.all_head_size, bias=config.qkv_bias ) self.value = nn.Linear( config.hidden_size, self.all_head_size, bias=config.qkv_bias ) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) outputs = ( (context_layer, attention_probs) if output_attentions else (context_layer,) ) return outputs class ViTSelfOutput(nn.Module): """ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the layernorm applied before each block. """ def __init__(self, config: ViTConfig) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( self, hidden_states: torch.Tensor, input_tensor: torch.Tensor ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states class ViTAttention(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() self.attention = ViTSelfAttention(config) self.output = ViTSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads: Set[int]) -> None: if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads, ) # Prune linear layers self.attention.query = prune_linear_layer(self.attention.query, index) self.attention.key = prune_linear_layer(self.attention.key, index) self.attention.value = prune_linear_layer(self.attention.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.attention.num_attention_heads = self.attention.num_attention_heads - len( heads ) self.attention.all_head_size = ( self.attention.attention_head_size * self.attention.num_attention_heads ) self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_outputs = self.attention(hidden_states, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[ 1: ] # add attentions if we output them return outputs class ViTIntermediate(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class ViTOutput(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( self, hidden_states: torch.Tensor, input_tensor: torch.Tensor ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = hidden_states + input_tensor return hidden_states def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) class ViTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTConfig) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = ViTAttention(config) self.intermediate = ViTIntermediate(config) self.output = ViTOutput(config) self.layernorm_before = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) self.layernorm_after = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) ) nn.init.constant_(self.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.adaLN_modulation[-1].bias, 0) def forward( self, hidden_states: torch.Tensor, adaln_input: torch.Tensor = None, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation( adaln_input ).chunk(4, dim=1) self_attention_outputs = self.attention( modulate( self.layernorm_before(hidden_states), shift_msa, scale_msa ), # in ViT, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[ 1: ] # add self attentions if we output attention weights # first residual connection hidden_states = attention_output + hidden_states # in ViT, layernorm is also applied after self-attention layer_output = modulate( self.layernorm_after(hidden_states), shift_mlp, scale_mlp ) layer_output = self.intermediate(layer_output) # second residual connection is done here layer_output = self.output(layer_output, hidden_states) outputs = (layer_output,) + outputs return outputs class ViTEncoder(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() self.config = config self.layer = nn.ModuleList( [ViTLayer(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, adaln_input: torch.Tensor = None, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, adaln_input, layer_head_mask, output_attentions, ) else: layer_outputs = layer_module( hidden_states, adaln_input, layer_head_mask, output_attentions ) hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class ViTPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = ViTConfig base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["ViTEmbeddings", "ViTLayer"] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues module.weight.data = nn.init.trunc_normal_( module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, ViTEmbeddings): module.position_embeddings.data = nn.init.trunc_normal_( module.position_embeddings.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, ).to(module.position_embeddings.dtype) module.cls_token.data = nn.init.trunc_normal_( module.cls_token.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, ).to(module.cls_token.dtype) class ViTModel(ViTPreTrainedModel): def __init__( self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, ): super().__init__(config) self.config = config self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) self.encoder = ViTEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pooler = ViTPooler(config) if add_pooling_layer else None # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> ViTPatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def forward( self, pixel_values: Optional[torch.Tensor] = None, adaln_input: Optional[torch.Tensor] = None, bool_masked_pos: Optional[torch.BoolTensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if pixel_values is None: raise ValueError("You have to specify pixel_values") # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype if pixel_values.dtype != expected_dtype: pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding, ) encoder_outputs = self.encoder( embedding_output, adaln_input=adaln_input, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) pooled_output = ( self.pooler(sequence_output) if self.pooler is not None else None ) if not return_dict: head_outputs = ( (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) ) return head_outputs + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class ViTPooler(nn.Module): def __init__(self, config: ViTConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class DinoWrapper(nn.Module): def __init__(self, model_name: str, freeze: bool = True): super().__init__() self.model, self.processor = self._build_dino(model_name) self.camera_embedder = nn.Sequential( nn.Linear(16, self.model.config.hidden_size, bias=True), nn.SiLU(), nn.Linear( self.model.config.hidden_size, self.model.config.hidden_size, bias=True ), ) if freeze: self._freeze() def forward(self, image, camera): if image.ndim == 5: image = image.view(-1, *image.shape[2:]) dtype = image.dtype inputs = ( self.processor( images=image.float(), return_tensors="pt", do_rescale=False, do_resize=False, ) .to(self.model.device) .to(dtype) ) # embed camera camera_embeddings = self.camera_embedder(camera) camera_embeddings = camera_embeddings.view(-1, camera_embeddings.shape[-1]) embeddings = camera_embeddings # This resampling of positional embedding uses bicubic interpolation outputs = self.model( **inputs, adaln_input=embeddings, interpolate_pos_encoding=True ) last_hidden_states = outputs.last_hidden_state return last_hidden_states def _freeze(self): self.model.eval() for name, param in self.model.named_parameters(): param.requires_grad = False @staticmethod def _build_dino( model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5 ): import requests try: model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) processor = ViTImageProcessor.from_pretrained(model_name) return model, processor except requests.exceptions.ProxyError as err: if proxy_error_retries > 0: print( f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds..." ) import time time.sleep(proxy_error_cooldown) return DinoWrapper._build_dino( model_name, proxy_error_retries - 1, proxy_error_cooldown ) else: raise err class BasicTransformerBlock(nn.Module): def __init__( self, inner_dim: int, cond_dim: int, num_heads: int, eps: float, attn_drop: float = 0.0, attn_bias: bool = False, mlp_ratio: float = 4.0, mlp_drop: float = 0.0, ): super().__init__() self.norm1 = nn.LayerNorm(inner_dim) self.cross_attn = nn.MultiheadAttention( embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, dropout=attn_drop, bias=attn_bias, batch_first=True, ) self.norm2 = nn.LayerNorm(inner_dim) self.self_attn = nn.MultiheadAttention( embed_dim=inner_dim, num_heads=num_heads, dropout=attn_drop, bias=attn_bias, batch_first=True, ) self.norm3 = nn.LayerNorm(inner_dim) self.mlp = nn.Sequential( nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), nn.GELU(), nn.Dropout(mlp_drop), nn.Linear(int(inner_dim * mlp_ratio), inner_dim), nn.Dropout(mlp_drop), ) def forward(self, x, cond): x = x + self.cross_attn(self.norm1(x), cond, cond)[0] before_sa = self.norm2(x) x = x + self.self_attn(before_sa, before_sa, before_sa)[0] x = x + self.mlp(self.norm3(x)) return x class TriplaneTransformer(nn.Module): def __init__( self, inner_dim: int, image_feat_dim: int, triplane_low_res: int, triplane_high_res: int, triplane_dim: int, num_layers: int, num_heads: int, eps: float = 1e-6, ): super().__init__() self.triplane_low_res = triplane_low_res self.triplane_high_res = triplane_high_res self.triplane_dim = triplane_dim self.pos_embed = nn.Parameter( torch.randn(1, 3 * triplane_low_res**2, inner_dim) * (1.0 / inner_dim) ** 0.5 ) self.layers = nn.ModuleList( [ BasicTransformerBlock( inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps, ) for _ in range(num_layers) ] ) self.norm = nn.LayerNorm(inner_dim, eps=eps) self.deconv = nn.ConvTranspose2d( inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0 ) def forward(self, image_feats): N = image_feats.shape[0] H = W = self.triplane_low_res x = self.pos_embed.repeat(N, 1, 1) for layer in self.layers: x = layer(x, image_feats) x = self.norm(x) x = x.view(N, 3, H, W, -1) x = torch.einsum("nihwd->indhw", x) x = x.contiguous().view(3 * N, -1, H, W) x = self.deconv(x) x = x.view(3, N, *x.shape[-3:]) x = torch.einsum("indhw->nidhw", x) x = x.contiguous() return x def interpolate_atlas(attr, rast, attr_idx, rast_db=None): return dr.interpolate( attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else "all", ) def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): _, indices, uvs = xatlas.parametrize( mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy() ) indices_int64 = indices.astype(np.uint64, casting="same_kind").view(np.int64) uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) uv_clip = uvs[None, ...] * 2.0 - 1.0 uv_clip4 = torch.cat( ( uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1]), ), dim=-1, ) rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) gb_pos, _ = interpolate_atlas(mesh_v[None, ...], rast, mesh_pos_idx.int()) mask = rast[..., 3:4] > 0 return uvs, mesh_tex_idx, gb_pos, mask class LRM(ModelMixin, ConfigMixin): def __init__( self, encoder_freeze: bool = False, encoder_model_name: str = "facebook/dino-vitb16", encoder_feat_dim: int = 768, transformer_dim: int = 1024, transformer_layers: int = 16, transformer_heads: int = 16, triplane_low_res: int = 32, triplane_high_res: int = 64, triplane_dim: int = 80, rendering_samples_per_ray: int = 128, grid_res: int = 128, grid_scale: float = 2.1, ): super().__init__() self.grid_res = grid_res self.grid_scale = grid_scale self.deformation_multiplier = 4.0 self.encoder = DinoWrapper( model_name=encoder_model_name, freeze=encoder_freeze, ) self.transformer = TriplaneTransformer( inner_dim=transformer_dim, num_layers=transformer_layers, num_heads=transformer_heads, image_feat_dim=encoder_feat_dim, triplane_low_res=triplane_low_res, triplane_high_res=triplane_high_res, triplane_dim=triplane_dim, ) self.synthesizer = TriplaneSynthesizer( triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray, ) def init_flexicubes_geometry(self, device, fovy=50.0): camera = PerspectiveCamera(fovy=fovy, device=device) renderer = NeuralRender(device, camera_model=camera) self.geometry = FlexiCubesGeometry( grid_res=self.grid_res, scale=self.grid_scale, renderer=renderer, render_type="neural_render", device=device, ) def forward_planes(self, images, cameras): B = images.shape[0] image_feats = self.encoder(images, cameras) image_feats = image_feats.view(B, -1, image_feats.shape[-1]) planes = self.transformer(image_feats) return planes def get_sdf_deformation_prediction(self, planes): init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) sdf, deformation, weight = torch.utils.checkpoint.checkpoint( self.synthesizer.get_geometry_prediction, planes, init_position, self.geometry.indices, use_reentrant=False, ) deformation = ( 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation) ) sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) sdf_bxnxnxn = sdf.reshape( (sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1) ) sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) if torch.sum(zero_surface).item() > 0: update_sdf = torch.zeros_like(sdf[0:1]) max_sdf = sdf.max() min_sdf = sdf.min() update_sdf[:, self.geometry.center_indices] += 1.0 - min_sdf update_sdf[:, self.geometry.boundary_indices] += -1 - max_sdf new_sdf = torch.zeros_like(sdf) for i_batch in range(zero_surface.shape[0]): if zero_surface[i_batch]: new_sdf[i_batch : i_batch + 1] += update_sdf update_mask = (new_sdf == 0).float() sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) sdf_reg_loss = sdf_reg_loss * zero_surface.float() sdf = sdf * update_mask + new_sdf * (1 - update_mask) final_sdf = [] final_def = [] for i_batch in range(zero_surface.shape[0]): if zero_surface[i_batch]: final_sdf.append(sdf[i_batch : i_batch + 1].detach()) final_def.append(deformation[i_batch : i_batch + 1].detach()) else: final_sdf.append(sdf[i_batch : i_batch + 1]) final_def.append(deformation[i_batch : i_batch + 1]) sdf = torch.cat(final_sdf, dim=0) deformation = torch.cat(final_def, dim=0) return sdf, deformation, sdf_reg_loss, weight def get_geometry_prediction(self, planes=None): sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction( planes ) v_deformed = ( self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation ) tets = self.geometry.indices n_batch = planes.shape[0] v_list = [] f_list = [] flexicubes_surface_reg_list = [] for i_batch in range(n_batch): verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( v_deformed[i_batch], sdf[i_batch].squeeze(dim=-1), with_uv=False, indices=tets, weight_n=weight[i_batch].squeeze(dim=-1), is_training=self.training, ) flexicubes_surface_reg_list.append(flexicubes_surface_reg) v_list.append(verts) f_list.append(faces) flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() flexicubes_weight_reg = (weight**2).mean() return ( v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg), ) def get_texture_prediction(self, planes, tex_pos, hard_mask=None): tex_pos = torch.cat(tex_pos, dim=0) if hard_mask is not None: tex_pos = tex_pos * hard_mask.float() batch_size = tex_pos.shape[0] tex_pos = tex_pos.reshape(batch_size, -1, 3) if hard_mask is not None: n_point_list = torch.sum( hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1 ) sample_tex_pose_list = [] max_point = n_point_list.max() expanded_hard_mask = ( hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 ) for i in range(tex_pos.shape[0]): tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) if tex_pos_one_shape.shape[1] < max_point: tex_pos_one_shape = torch.cat( [ tex_pos_one_shape, torch.zeros( 1, max_point - tex_pos_one_shape.shape[1], 3, device=tex_pos_one_shape.device, dtype=torch.float32, ), ], dim=1, ) sample_tex_pose_list.append(tex_pos_one_shape) tex_pos = torch.cat(sample_tex_pose_list, dim=0) tex_feat = torch.utils.checkpoint.checkpoint( self.synthesizer.get_texture_prediction, planes, tex_pos, use_reentrant=False, ) if hard_mask is not None: final_tex_feat = torch.zeros( planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device, ) expanded_hard_mask = ( hard_mask.reshape(hard_mask.shape[0], -1, 1).expand( -1, -1, final_tex_feat.shape[-1] ) > 0.5 ) for i in range(planes.shape[0]): final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][ : n_point_list[i] ].reshape(-1) tex_feat = final_tex_feat return tex_feat.reshape( planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1] ) def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256): return_value_list = [] for i_mesh in range(len(mesh_v)): return_value = self.geometry.render_mesh( mesh_v[i_mesh], mesh_f[i_mesh].int(), cam_mv[i_mesh], resolution=render_size, hierarchical_mask=False, ) return_value_list.append(return_value) return_keys = return_value_list[0].keys() return_value = dict() for k in return_keys: value = [v[k] for v in return_value_list] return_value[k] = value mask = torch.cat(return_value["mask"], dim=0) hard_mask = torch.cat(return_value["hard_mask"], dim=0) tex_pos = return_value["tex_pos"] depth = torch.cat(return_value["depth"], dim=0) normal = torch.cat(return_value["normal"], dim=0) return mask, hard_mask, tex_pos, depth, normal def forward_geometry(self, planes, render_cameras, render_size=256): B, NV = render_cameras.shape[:2] mesh_v, mesh_f, sdf, _, _, sdf_reg_loss = self.get_geometry_prediction(planes) cam_mv = render_cameras run_n_view = cam_mv.shape[1] antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh( mesh_v, mesh_f, cam_mv, render_size=render_size ) tex_hard_mask = hard_mask tex_pos = [ torch.cat([pos[i_view : i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos ] tex_hard_mask = torch.cat( [ torch.cat( [ tex_hard_mask[ i * run_n_view + i_view : i * run_n_view + i_view + 1 ] for i_view in range(run_n_view) ], dim=2, ) for i in range(planes.shape[0]) ], dim=0, ) tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) background_feature = torch.ones_like(tex_feat) img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) img_feat = torch.cat( [ torch.cat( [ img_feat[ i : i + 1, :, render_size * i_view : render_size * (i_view + 1), ] for i_view in range(run_n_view) ], dim=0, ) for i in range(len(tex_pos)) ], dim=0, ) img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) out = { "img": img, "mask": antilias_mask, "depth": depth, "normal": normal, "sdf": sdf, "mesh_v": mesh_v, "mesh_f": mesh_f, "sdf_reg_loss": sdf_reg_loss, } return out def forward(self, images, cameras, render_cameras, render_size: int): planes = self.forward_planes(images, cameras) out = self.forward_geometry(planes, render_cameras, render_size=render_size) return {"planes": planes, **out} def extract_mesh( self, planes: torch.Tensor, use_texture_map: bool = False, texture_resolution: int = 1024, progress_callback: Optional[Callable[[float], None]] = None, **kwargs, ): """ Extract a 3D mesh from FlexiCubes. Only support batch_size 1. :param planes: triplane features :param use_texture_map: use texture map or vertex color :param texture_resolution: the resolution of texure map """ assert planes.shape[0] == 1 if progress_callback is not None: progress_callback(0.0) mesh_v, mesh_f, _, _, _, _ = self.get_geometry_prediction(planes) vertices, faces = mesh_v[0], mesh_f[0] if progress_callback is not None: progress_callback(0.5) if not use_texture_map: vertices_tensor = vertices.unsqueeze(0) vertices_colors = ( self.synthesizer.get_texture_prediction(planes, vertices_tensor) .clamp(0, 1) .squeeze(0) .cpu() .numpy() ) vertices_colors = (vertices_colors * 255).astype(np.uint8) if progress_callback is not None: progress_callback(1.0) return vertices, faces, vertices_colors uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution ) tex_hard_mask = tex_hard_mask.float() tex_feat = self.get_texture_prediction(planes, [gb_pos], tex_hard_mask) background_feature = torch.zeros_like(tex_feat) img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) if progress_callback is not None: progress_callback(1.0) return vertices, faces, uvs, mesh_tex_idx, texture_map