import gc
import json
import os
import time

import h5py
import lightning.pytorch as pl
import numpy as np
import skimage
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import trimesh
from plyfile import PlyData, PlyElement
from torch.utils.data import DataLoader

from partfield.model.model_utils import VanillaMLP
from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
from partfield.model.triplane import TriplaneTransformer, get_grid_coord
from partfield.model.UNet.model import ResidualUNet3D

from .dataloader import Correspondence_Demo_Dataset, Demo_Dataset, Demo_Remesh_Dataset

class Model(pl.LightningModule):
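    """Lightning wrapper around the PartField feature predictor.

    A PVCNN-based point-cloud encoder produces coarse triplane features, a
    triplane transformer refines them, and MLP heads decode SDF (and
    optionally 2D) features. ``predict_step`` exports per-point or per-face
    part features together with a PCA color visualization.
    """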
    def __init__(self, cfg):
        super().__init__()

        self.save_hyperparameters()
        self.cfg = cfg
        self.automatic_optimization = False

        self.triplane_resolution = cfg.triplane_resolution
        self.triplane_channels_low = cfg.triplane_channels_low
        # Refines low-resolution triplane features (32^2) into high-resolution ones (128^2).
        self.triplane_transformer = TriplaneTransformer(
            input_dim=cfg.triplane_channels_low * 2,
            transformer_dim=1024,
            transformer_layers=6,
            transformer_heads=8,
            triplane_low_res=32,
            triplane_high_res=128,
            triplane_dim=cfg.triplane_channels_high,
        )
        # Decodes SDF values from the first 64 triplane channels.
        self.sdf_decoder = VanillaMLP(input_dim=64,
                                      output_dim=1,
                                      out_activation="tanh",
                                      n_neurons=64,
                                      n_hidden_layers=6)
        self.use_pvcnn = cfg.use_pvcnnonly
        self.use_2d_feat = cfg.use_2d_feat
        if self.use_pvcnn:
            # Point-cloud encoder that lifts input points to triplane features.
            self.pvcnn = TriPlanePC2Encoder(
                cfg.pvcnn,
                device="cuda",
                shape_min=-1,
                shape_length=2,
                use_2d_feat=self.use_2d_feat)
        self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
        self.grid_coord = get_grid_coord(256)
        self.mse_loss = torch.nn.MSELoss()
        self.l1_loss = torch.nn.L1Loss(reduction='none')

        if cfg.regress_2d_feat:
            # Optional head that regresses 192-dim 2D features from the triplanes.
            self.feat_decoder = VanillaMLP(input_dim=64,
                                           output_dim=192,
                                           out_activation="GELU",
                                           n_neurons=64,
                                           n_hidden_layers=6)

    def predict_dataloader(self):
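        """Return the demo dataloader selected by the config flags."""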
        if self.cfg.remesh_demo:
            dataset = Demo_Remesh_Dataset(self.cfg)
        elif self.cfg.correspondence_demo:
            dataset = Correspondence_Demo_Dataset(self.cfg)
        else:
            dataset = Demo_Dataset(self.cfg)

        dataloader = DataLoader(dataset,
                                num_workers=self.cfg.dataset.val_num_workers,
                                batch_size=self.cfg.dataset.val_batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)
        return dataloader

    @torch.no_grad()
    def predict_step(self, batch, batch_idx):
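        """Predict part features for a single shape and export visualizations.

        Depending on the config, features are sampled at the input points
        (``is_pc``), at mesh vertices (``vertex_feature``), or averaged over
        points sampled on each face; results are saved as ``.npy`` features
        plus a PCA-colored ``.ply``.
        """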
        save_dir = f"{self.cfg.result_name}"
        os.makedirs(save_dir, exist_ok=True)

        uid = batch['uid'][0]
        view_id = 0
        starttime = time.time()

        if uid in ("car", "complex_car"):
            print(f"Skipping {uid} for now.")
            return

        # Skip shapes whose features have already been exported.
        if os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}.npy') or os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy'):
            print("Already processed " + uid)
            return

        N = batch['pc'].shape[0]
        assert N == 1  # demo inference expects batch size 1

        if self.use_2d_feat:
            # The demo dataloader does not provide input 2D features.
            raise NotImplementedError("Dataloader not implemented with input 2d feat.")
        else:
            pc_feat = self.pvcnn(batch['pc'], batch['pc'])

        planes = pc_feat
        planes = self.triplane_transformer(planes)
        # The first 64 triplane channels decode the SDF; the rest carry part features.
        sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

        if self.cfg.is_pc:
            # Point-cloud input: sample part features directly at the input points.
            tensor_vertices = batch['pc'].reshape(1, -1, 3).cuda().to(torch.float16)
            point_feat = sample_triplane_feat(part_planes, tensor_vertices)
            point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)

            np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat)
            print(f"Exported part_feat_{uid}_{view_id}.npy")

            # Reduce the normalized features to three PCA components for RGB visualization.
            from sklearn.decomposition import PCA
            data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
            pca = PCA(n_components=3)
            data_reduced = pca.fit_transform(data_scaled)
            data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
            colors_255 = (data_reduced * 255).astype(np.uint8)

            points = batch['pc'].squeeze().detach().cpu().numpy()
            assert colors_255.shape == points.shape, "Colors must have the same shape as points"

            vertex_data = np.array(
                [(*point, *color) for point, color in zip(points, colors_255)],
                dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")]
            )

            el = PlyElement.describe(vertex_data, "vertex")

            filename = f'{save_dir}/feat_pca_{uid}_{view_id}.ply'
            PlyData([el], text=True).write(filename)
            print(f"Saved PLY file: {filename}")

        else:
            use_cuda_version = True
            if use_cuda_version:
                def sample_points(vertices, faces, n_point_per_face):
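                    """Uniformly sample ``n_point_per_face`` points on every face.

                    Uses the square-root trick: for uniform U1, U2, taking
                    u = sqrt(U1) and v = U2 makes the barycentric weights
                    (1 - u, u * (1 - v), u * v) uniform over each triangle.
                    """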
                    n_f = faces.shape[0]
                    u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1),
                                              device=vertices.device,
                                              dtype=vertices.dtype))
                    v = torch.rand((n_f, n_point_per_face, 1),
                                   device=vertices.device,
                                   dtype=vertices.dtype)
                    w0 = 1 - u
                    w1 = u * (1 - v)
                    w2 = u * v

                    # Gather the three vertices of every face and blend them.
                    face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1))
                    face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1))
                    face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1))
                    points = (w0 * face_v_0.unsqueeze(dim=1)
                              + w1 * face_v_1.unsqueeze(dim=1)
                              + w2 * face_v_2.unsqueeze(dim=1))
                    return points

                def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face):
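                    """Sample triplane features in chunks to limit peak GPU memory.

                    At most ``cfg.n_sample_each`` points are queried per call;
                    every ``n_point_per_face`` consecutive samples are then
                    mean-pooled into a single feature.
                    """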
                    n_sample_each = self.cfg.n_sample_each
                    n_v = tensor_vertices.shape[1]
                    n_sample = (n_v + n_sample_each - 1) // n_sample_each  # ceil division
                    all_sample = []
                    for i_sample in range(n_sample):
                        sampled_feature = sample_triplane_feat(
                            part_planes,
                            tensor_vertices[:, i_sample * n_sample_each: (i_sample + 1) * n_sample_each])
                        assert sampled_feature.shape[1] % n_point_per_face == 0
                        sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1])
                        sampled_feature = torch.mean(sampled_feature, dim=-2)
                        all_sample.append(sampled_feature)
                    return torch.cat(all_sample, dim=1)

                if self.cfg.vertex_feature:
                    # One feature per mesh vertex.
                    tensor_vertices = batch['vertices'][0].reshape(1, -1, 3).to(torch.float32)
                    point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, 1)
                else:
                    # One feature per face, averaged over points sampled on that face.
                    n_point_per_face = self.cfg.n_point_per_face
                    tensor_vertices = sample_points(batch['vertices'][0], batch['faces'][0], n_point_per_face)
                    tensor_vertices = tensor_vertices.reshape(1, -1, 3).to(torch.float32)
                    point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face)

                print("Time elapsed for feature prediction: " + str(time.time() - starttime))
                point_feat = point_feat.reshape(-1, 448).cpu().numpy()
                np.save(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy', point_feat)
                print(f"Exported part_feat_{uid}_{view_id}_batch.npy")

                # Reduce the normalized features to three PCA components for RGB visualization.
                from sklearn.decomposition import PCA
                data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
                pca = PCA(n_components=3)
                data_reduced = pca.fit_transform(data_scaled)
                data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
                colors_255 = (data_reduced * 255).astype(np.uint8)

                V = batch['vertices'][0].cpu().numpy()
                F = batch['faces'][0].cpu().numpy()
                if self.cfg.vertex_feature:
                    colored_mesh = trimesh.Trimesh(vertices=V, faces=F, vertex_colors=colors_255, process=False)
                else:
                    colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False)
                colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply')

                torch.cuda.empty_cache()

            else:
                # CPU fallback: sample each face one at a time with NumPy.
                V = batch['vertices'][0].cpu().numpy()
                F = batch['faces'][0].cpu().numpy()

                num_samples_per_face = self.cfg.n_point_per_face

                all_point_feats = []
                for face in F:
                    v0, v1, v2 = V[face]

                    # Sample barycentric coordinates; reflect samples with
                    # u + v > 1 back inside the triangle.
                    u = np.random.rand(num_samples_per_face, 1)
                    v = np.random.rand(num_samples_per_face, 1)
                    is_outside = (u + v) > 1
                    u[is_outside] = 1 - u[is_outside]
                    v[is_outside] = 1 - v[is_outside]
                    w = 1 - u - v

                    points = u * v0 + v * v1 + w * v2

                    tensor_vertices = torch.from_numpy(points.copy()).reshape(1, -1, 3).cuda().to(torch.float32)
                    point_feat = sample_triplane_feat(part_planes, tensor_vertices)

                    # Average the sampled features into one feature per face.
                    point_feat = torch.mean(point_feat, dim=1).cpu().detach().numpy()
                    all_point_feats.append(point_feat)

                all_point_feats = np.array(all_point_feats).reshape(-1, 448)
                point_feat = all_point_feats

                np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat)
                print(f"Exported part_feat_{uid}_{view_id}.npy")

                # Reduce the normalized features to three PCA components for RGB visualization.
                from sklearn.decomposition import PCA
                data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
                pca = PCA(n_components=3)
                data_reduced = pca.fit_transform(data_scaled)
                data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
                colors_255 = (data_reduced * 255).astype(np.uint8)

                colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False)
                colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply')

        print("Time elapsed: " + str(time.time() - starttime))
        return