import torch
import lightning.pytorch as pl
# from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset
from torch.utils.data import DataLoader
from partfield.model.UNet.model import ResidualUNet3D
from partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane
from partfield.model.model_utils import VanillaMLP
import torch.nn.functional as F
import torch.nn as nn
import os
import trimesh
import skimage
import numpy as np
import h5py
import torch.distributed as dist
from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
import json
import gc
import time
from plyfile import PlyData, PlyElement


class Model(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.save_hyperparameters()
        self.cfg = cfg
        self.automatic_optimization = False
        self.triplane_resolution = cfg.triplane_resolution
        self.triplane_channels_low = cfg.triplane_channels_low

        # Transformer that refines low-res triplane features (32^2) into
        # high-res planes (128^2) with cfg.triplane_channels_high channels.
        self.triplane_transformer = TriplaneTransformer(
            input_dim=cfg.triplane_channels_low * 2,
            transformer_dim=1024,
            transformer_layers=6,
            transformer_heads=8,
            triplane_low_res=32,
            triplane_high_res=128,
            triplane_dim=cfg.triplane_channels_high,
        )

        # MLP head that decodes a signed-distance value from 64 triplane channels.
        self.sdf_decoder = VanillaMLP(input_dim=64,
                                      output_dim=1,
                                      out_activation="tanh",
                                      n_neurons=64,
                                      n_hidden_layers=6)

        # Point-cloud encoder (PVCNN) that lifts the input points to triplanes.
        self.use_pvcnn = cfg.use_pvcnnonly
        self.use_2d_feat = cfg.use_2d_feat
        if self.use_pvcnn:
            self.pvcnn = TriPlanePC2Encoder(
                cfg.pvcnn,
                device="cuda",
                shape_min=-1,       # input coordinates are expected in [-1, 1]
                shape_length=2,
                use_2d_feat=self.use_2d_feat)

        self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
        self.grid_coord = get_grid_coord(256)
        self.mse_loss = torch.nn.MSELoss()
        self.l1_loss = torch.nn.L1Loss(reduction='none')

        # Optional head that regresses 192-d 2D features per query point.
        if cfg.regress_2d_feat:
            self.feat_decoder = VanillaMLP(input_dim=64,
                                           output_dim=192,
                                           out_activation="GELU",
                                           n_neurons=64,
                                           n_hidden_layers=6)
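
    # A sketch of the cfg fields this module reads (field names taken from the
    # code in this file; the example values are illustrative assumptions, not
    # shipped defaults). triplane_channels_high = 512 would be consistent with
    # the 448-channel part features in encode() (512 total minus 64 SDF channels):
    #
    #   triplane_resolution: 128
    #   triplane_channels_low: 128
    #   triplane_channels_high: 512
    #   use_pvcnnonly: True
    #   use_2d_feat: False
    #   regress_2d_feat: False
    #   remesh_demo: False            # selects Demo_Remesh_Dataset below
    #   correspondence_demo: False    # selects Correspondence_Demo_Dataset below
    #   pvcnn: ...                    # sub-config forwarded to TriPlanePC2Encoder
    #   dataset:
    #     val_num_workers: 4
    #     val_batch_size: 1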
    # def predict_dataloader(self):
    #     if self.cfg.remesh_demo:
    #         dataset = Demo_Remesh_Dataset(self.cfg)
    #     elif self.cfg.correspondence_demo:
    #         dataset = Correspondence_Demo_Dataset(self.cfg)
    #     else:
    #         dataset = Demo_Dataset(self.cfg)
    #     dataloader = DataLoader(dataset,
    #                             num_workers=self.cfg.dataset.val_num_workers,
    #                             batch_size=self.cfg.dataset.val_batch_size,
    #                             shuffle=False,
    #                             pin_memory=True,
    #                             drop_last=False)
    #     return dataloader
    @torch.no_grad()
    def encode(self, points):
        N = points.shape[0]
        # assert N == 1
        pcd = points[..., :3]

        # Lift the point cloud to triplane features, then refine them with
        # the triplane transformer.
        pc_feat = self.pvcnn(pcd, pcd)
        planes = self.triplane_transformer(pc_feat)

        # Split the channel dimension: the first 64 channels feed the SDF
        # decoder; the remaining channels carry part features.
        sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

        # Sample a part-feature vector for every input point from the planes.
        tensor_vertices = pcd.reshape(N, -1, 3).cuda().to(pcd.dtype)
        point_feat = sample_triplane_feat(part_planes, tensor_vertices)  # N, M, C
        # point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)
        point_feat = point_feat.reshape(N, -1, 448)
        return point_feat
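

# Minimal usage sketch (illustrative, not part of the original pipeline): load
# a config, restore a checkpoint, and encode a point cloud into per-point part
# features. `load_cfg` and the file paths are hypothetical placeholders; the
# PVCNN encoder expects coordinates in [-1, 1] (shape_min=-1, shape_length=2).
#
#     cfg = load_cfg("configs/demo.yaml")  # hypothetical config loader
#     model = Model.load_from_checkpoint("model.ckpt", cfg=cfg).cuda().eval()
#     points = torch.rand(1, 100000, 3, device="cuda") * 2 - 1  # (N, M, 3)
#     feat = model.encode(points)  # -> (N, M, 448) part features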