import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import cv2
import trimesh
import nvdiffrast.torch as dr
from model.archs.decoders.shape_texture_net import TetTexNet
from model.archs.unet import UNetPP
from util.renderer import Renderer
from model.archs.mlp_head import SdfMlp, RgbMlp
import xatlas


class Dummy:
    pass


class CRM(nn.Module):
    def __init__(self, specs):
        super(CRM, self).__init__()

        self.specs = specs
        # configs
        input_specs = specs["Input"]
        self.input = Dummy()
        self.input.scale = input_specs['scale']
        self.input.resolution = input_specs['resolution']
        self.tet_grid_size = input_specs['tet_grid_size']
        self.camera_angle_num = input_specs['camera_angle_num']

        self.arch = Dummy()
        self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"]
        self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"]

        self.dec = Dummy()
        self.dec.c_dim = specs["DecoderSpecs"]["c_dim"]
        self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"]

        self.geo_type = specs["Train"].get("geo_type", "flex")  # "dmtet" or "flex"

        self.unet2 = UNetPP(in_channels=self.dec.c_dim)

        mlp_chnl_s = 3 if self.arch.fea_concat else 1  # 3 for queried triplane feature concatenation
        self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat)

        if self.geo_type == "flex":
            self.weightMlp = nn.Sequential(
                nn.Linear(mlp_chnl_s * 32 * 8, 512),
                nn.SiLU(),
                nn.Linear(512, 21))

        self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
        self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
        # self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num,
        #                          scale=self.input.scale, geo_type=self.geo_type)

        self.spob = specs['Pretrain']['mode'] is None  # whether to add a sphere offset to the predicted SDF
        self.radius = specs['Pretrain']['radius']  # used when spob is True

        self.denoising = True
        from diffusers import DDIMScheduler
        self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
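
    # Illustrative sketch of the `specs` layout implied by the reads in __init__
    # above (key names come from the code; the values here are placeholders,
    # not the project's actual defaults):
    #
    #     specs = {
    #         "Input": {"scale": ..., "resolution": ..., "tet_grid_size": ..., "camera_angle_num": ...},
    #         "ArchSpecs": {"fea_concat": True, "mlp_bias": True},
    #         "DecoderSpecs": {"c_dim": ..., "plane_resolution": ...},
    #         "Train": {"geo_type": "flex"},              # or "dmtet"
    #         "Pretrain": {"mode": None, "radius": ...},  # mode=None enables the sphere offset (spob)
    #     }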
    def decode(self, data, triplane_feature2):
        if self.geo_type == "flex":
            tet_verts = self.renderer.flexicubes.verts.unsqueeze(0)
            tet_indices = self.renderer.flexicubes.indices

        dec_verts = self.decoder(triplane_feature2, tet_verts)
        out = self.sdfMlp(dec_verts)

        weight = None
        if self.geo_type == "flex":
            grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1), dim=1)
            grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0],
                                          self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1])
            weight = self.weightMlp(grid_feat)
            weight = weight * 0.1

        pred_sdf, deformation = out[..., 0], out[..., 1:]
        if self.spob:
            pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts ** 2).sum(-1))

        _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight=weight)
        return verts[0].unsqueeze(0), faces[0].int()
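
    # Note: `self.renderer` is read by decode() but never assigned in __init__
    # (the constructor call there is commented out), so the calling code is
    # assumed to attach one before decoding, e.g. (hypothetical wiring, reusing
    # the commented-out constructor arguments):
    #
    #     model.renderer = Renderer(tet_grid_size=model.tet_grid_size,
    #                               camera_angle_num=model.camera_angle_num,
    #                               scale=model.input.scale,
    #                               geo_type=model.geo_type)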
    def export_mesh(self, data, out_dir, tri_fea_2=None):
        verts = data['verts']
        faces = data['faces']

        dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0))
        colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
        # Predicted colors are expected in [-1, 1]; map them to [0, 1]
        colors = (colors * 0.5 + 0.5).clip(0, 1)

        verts = verts[..., [0, 2, 1]]
        verts[..., 0] *= -1
        verts[..., 2] *= -1
        verts = verts.squeeze().cpu().numpy()
        faces = faces[..., [2, 1, 0]][..., [0, 2, 1]]  # [..., [1, 0, 2]]
        faces = faces.squeeze().cpu().numpy()  # faces[..., [2, 1, 0]].squeeze().cpu().numpy()

        # export the final mesh with per-vertex colors
        with torch.no_grad():
            # important: process=True leads to a seg fault, so keep process=False
            mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_colors=colors, process=False)
            mesh.export(f"{out_dir}.glb", file_type="glb")
    def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
        mesh_v = data['verts'].squeeze().cpu().numpy()
        mesh_pos_idx = data['faces'].squeeze().cpu().numpy()

        def interpolate(attr, rast, attr_idx, rast_db=None):
            return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db,
                                  diff_attrs=None if rast_db is None else 'all')

        vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx)

        mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device)
        mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device)

        # Convert the xatlas outputs to tensors
        indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
        uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
        mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)

        # map UVs from [0, 1] to clip space [-1, 1]
        uv_clip = uvs[None, ...] * 2.0 - 1.0

        # pad to a four-component coordinate
        uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)

        # rasterize in UV space
        rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res)

        # interpolate world-space positions
        gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
        mask = rast[..., 3:4] > 0

        # flatten and keep only the covered texels
        gb_pos_unsqz = gb_pos.view(-1, 3)
        mask_unsqz = mask.view(-1)
        tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1

        gb_mask_pos = gb_pos_unsqz[mask_unsqz]
        gb_mask_pos = gb_mask_pos[None, ]
        with torch.no_grad():
            dec_verts = self.decoder(tri_fea_2, gb_mask_pos)
            colors = self.rgbMlp(dec_verts).squeeze()

        # Predicted colors are expected in [-1, 1]; map them to [0, 255]
        lo, hi = (-1, 1)
        colors = (colors - lo) * (255 / (hi - lo))
        colors = colors.clip(0, 255)

        tex_unsqz[mask_unsqz] = colors
        tex = tex_unsqz.view(res + (3,))

        verts = mesh_v.squeeze().cpu().numpy()
        faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy()
        # faces = mesh_pos_idx
        # faces = faces.detach().cpu().numpy()
        # faces = faces[..., [2, 1, 0]]
        indices = indices[..., [2, 1, 0]]

        # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs)

        # write the material file
        matname = f'{out_dir}.mtl'
        # matname = f'{out_dir}/{ind}.mtl'
        fid = open(matname, 'w')
        fid.write('newmtl material_0\n')
        fid.write('Kd 1 1 1\n')
        fid.write('Ka 1 1 1\n')
        # fid.write('Ks 0 0 0\n')
        fid.write('Ks 0.4 0.4 0.4\n')
        fid.write('Ns 10\n')
        fid.write('illum 2\n')
        fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n')
        fid.close()

        # write the OBJ with positions, UVs, and per-face position/UV indices
        fid = open(f'{out_dir}.obj', 'w')
        # fid = open(f'{out_dir}/{ind}.obj', 'w')
        fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1])

        for pidx, p in enumerate(verts):
            pp = p
            fid.write('v %f %f %f\n' % (pp[0], pp[2], -pp[1]))

        for pidx, p in enumerate(uvs):
            pp = p
            fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))

        fid.write('usemtl material_0\n')
        for i, f in enumerate(faces):
            f1 = f + 1
            f2 = indices[i] + 1
            fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
        fid.close()

        # bake the texture: fill uncovered texels (still at the init value) with
        # dilated neighbor colors so sampling at UV seams does not bleed in the background
        img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32)
        mask = np.sum(img.astype(float), axis=-1, keepdims=True)
        mask = (mask <= 3.0).astype(float)
        kernel = np.ones((3, 3), 'uint8')
        dilate_img = cv2.dilate(img, kernel, iterations=1)
        img = img * (1 - mask) + dilate_img * mask
        img = img.clip(0, 255).astype(np.uint8)
        cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
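

# Usage sketch for the UV-textured export path (assumptions: a CUDA device, an
# nvdiffrast rasterizer context, and a 1024x1024 texture; everything except the
# method's own parameters is a placeholder):
#
#     glctx = dr.RasterizeCudaContext()
#     model.export_mesh_wt_uv(glctx, mesh_data, "out/mesh_name", ind=0,
#                             device="cuda", res=(1024, 1024),
#                             tri_fea_2=triplane_feature2)
#     # -> writes out/mesh_name.obj, out/mesh_name.mtl and out/mesh_name.png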