|
import torch |
|
import lightning.pytorch as pl |
|
|
|
from torch.utils.data import DataLoader |
|
from partfield.model.UNet.model import ResidualUNet3D |
|
from partfield.model.triplane import TriplaneTransformer, get_grid_coord |
|
from partfield.model.model_utils import VanillaMLP |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
import os |
|
import trimesh |
|
import skimage |
|
import numpy as np |
|
import h5py |
|
import torch.distributed as dist |
|
from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat |
|
import json |
|
import gc |
|
import time |
|
from plyfile import PlyData, PlyElement |
|
|
|
|
|
class Model(pl.LightningModule): |
|
def __init__(self, cfg): |
|
super().__init__() |
|
|
|
self.save_hyperparameters() |
|
self.cfg = cfg |
|
self.automatic_optimization = False |
|
self.triplane_resolution = cfg.triplane_resolution |
|
self.triplane_channels_low = cfg.triplane_channels_low |
|
self.triplane_transformer = TriplaneTransformer( |
|
input_dim=cfg.triplane_channels_low * 2, |
|
transformer_dim=1024, |
|
transformer_layers=6, |
|
transformer_heads=8, |
|
triplane_low_res=32, |
|
triplane_high_res=128, |
|
triplane_dim=cfg.triplane_channels_high, |
|
) |
|
self.sdf_decoder = VanillaMLP(input_dim=64, |
|
output_dim=1, |
|
out_activation="tanh", |
|
n_neurons=64, |
|
n_hidden_layers=6) |
|
self.use_pvcnn = cfg.use_pvcnnonly |
|
self.use_2d_feat = cfg.use_2d_feat |
|
if self.use_pvcnn: |
|
self.pvcnn = TriPlanePC2Encoder( |
|
cfg.pvcnn, |
|
device="cuda", |
|
shape_min=-1, |
|
shape_length=2, |
|
use_2d_feat=self.use_2d_feat) |
|
self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True)) |
|
self.grid_coord = get_grid_coord(256) |
|
self.mse_loss = torch.nn.MSELoss() |
|
self.l1_loss = torch.nn.L1Loss(reduction='none') |
|
|
|
if cfg.regress_2d_feat: |
|
self.feat_decoder = VanillaMLP(input_dim=64, |
|
output_dim=192, |
|
out_activation="GELU", |
|
n_neurons=64, |
|
n_hidden_layers=6) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
def encode(self, points): |
|
|
|
N = points.shape[0] |
|
|
|
pcd = points[..., :3] |
|
|
|
pc_feat = self.pvcnn(pcd, pcd) |
|
|
|
planes = pc_feat |
|
planes = self.triplane_transformer(planes) |
|
sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2) |
|
|
|
tensor_vertices = pcd.reshape(N, -1, 3).cuda().to(pcd.dtype) |
|
point_feat = sample_triplane_feat(part_planes, tensor_vertices) |
|
|
|
point_feat = point_feat.reshape(N, -1, 448) |
|
|
|
return point_feat |