# Hugging Face Space: mashroo / CRM -- file: inference.py
# Space status when this copy was captured: Runtime error
# Last commit f6ac519 by YoussefAnso (file size 4.19 kB):
#   Refactor generate3d function in inference.py to implement CPU-only UV
#   unwrapping using xatlas and trimesh for improved texture mapping. Updated
#   texture baking logic to utilize the model's decoder and rgbMlp, enhancing
#   the quality of generated textures in mesh exports.
import numpy as np
import torch
import time
import tempfile
import zipfile
import nvdiffrast.torch as dr
import xatlas
import cv2
import trimesh
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
from kiui.mesh_utils import clean_mesh
def _rasterize_uv_face_ids(vt, ft, height, width, shift=4):
    """Rasterize UV-space triangles into a per-texel face-index map on the CPU.

    Args:
        vt: (M, 2) UV coordinates in [0, 1], as returned by ``xatlas.parametrize``.
        ft: (F, 3) integer indices into ``vt``, one row per face.
        height, width: Texture resolution in texels.
        shift: Fractional bits handed to ``cv2.fillConvexPoly`` so triangle
            corners keep sub-texel precision.

    Returns:
        (height, width) int32 array holding, for every texel, the index of the
        face drawn over it, or -1 where no face was drawn. Where faces overlap
        (e.g. along shared edges) the face drawn last wins.
    """
    face_ids = np.full((height, width), -1, dtype=np.int32)
    if len(ft) == 0:
        return face_ids
    # Texel (row, col) represents uv = ((col + 0.5) / W, (row + 0.5) / H).
    # OpenCV puts pixel centres on integer coordinates, hence the -0.5.
    # NOTE(review): row index grows with v (v = 0 is the first image row, the
    # glTF convention); TODO confirm this matches how mesh.Mesh writes vt.
    uv = np.asarray(vt, dtype=np.float64)
    pix = np.stack([uv[:, 0] * width - 0.5, uv[:, 1] * height - 0.5], axis=-1)
    corners = np.round(pix[np.asarray(ft, dtype=np.int64)] * (1 << shift)).astype(np.int32)
    for face_index, polygon in enumerate(corners):
        cv2.fillConvexPoly(face_ids, polygon, int(face_index), lineType=cv2.LINE_8, shift=shift)
    return face_ids


def _interpolate_texel_positions(face_ids, vt, ft, vt_positions, eps=1e-18):
    """Return the 3D surface point behind every texel covered in ``face_ids``.

    Args:
        face_ids: (H, W) int map from ``_rasterize_uv_face_ids`` (-1 = empty).
        vt: (M, 2) UV coordinates.
        ft: (F, 3) indices into ``vt``.
        vt_positions: (M, 3) 3D position of each UV vertex.
        eps: Threshold below which a UV triangle is treated as degenerate
            (its texels then take the triangle's centroid).

    Returns:
        ``(positions, rows, cols)`` where ``positions`` is (K, 3) float32 and
        row k is the surface point for texel ``(rows[k], cols[k])``.
    """
    height, width = face_ids.shape
    rows, cols = np.nonzero(face_ids >= 0)
    tri_idx = np.asarray(ft, dtype=np.int64).reshape(-1, 3)[face_ids[rows, cols]]  # (K, 3)

    uv = np.asarray(vt, dtype=np.float64)
    a, b, c = uv[tri_idx[:, 0]], uv[tri_idx[:, 1]], uv[tri_idx[:, 2]]
    p = np.stack([(cols + 0.5) / width, (rows + 0.5) / height], axis=-1)

    # Standard barycentric solve via dot products of the edge vectors.
    e0, e1, e2 = b - a, c - a, p - a
    d00 = (e0 * e0).sum(-1)
    d01 = (e0 * e1).sum(-1)
    d11 = (e1 * e1).sum(-1)
    d20 = (e2 * e0).sum(-1)
    d21 = (e2 * e1).sum(-1)
    denom = d00 * d11 - d01 * d01
    ok = np.abs(denom) > eps
    safe = np.where(ok, denom, 1.0)
    w1 = np.where(ok, (d11 * d20 - d01 * d21) / safe, 1.0 / 3.0)
    w2 = np.where(ok, (d00 * d21 - d01 * d20) / safe, 1.0 / 3.0)
    bary = np.stack([1.0 - w1 - w2, w1, w2], axis=-1)
    # The polygon fill is inclusive at edges, so some texel centres fall
    # slightly outside their triangle: clamp onto the triangle and renormalise.
    bary = np.clip(bary, 0.0, None)
    bary /= bary.sum(-1, keepdims=True)

    corner_pos = np.asarray(vt_positions, dtype=np.float64)[tri_idx]  # (K, 3, 3)
    positions = (bary[..., None] * corner_pos).sum(axis=1)
    return positions.astype(np.float32), rows, cols


def _query_texel_colors(model, triplane_feature, positions, chunk_size=65536):
    """Evaluate ``model.decoder`` + ``model.rgbMlp`` at 3D points, in chunks.

    Points are moved to ``triplane_feature``'s device so the query works on
    GPU as well as CPU.

    Args:
        model: Object exposing ``decoder(triplane_feature, points)`` and
            ``rgbMlp(features)``.
        triplane_feature: Triplane feature tensor from ``model.unet2``.
        positions: (K, 3) array of query points.
        chunk_size: Maximum number of points evaluated per forward pass.

    Returns:
        (K, 3) float32 colours clipped to [0, 1].
    """
    device = triplane_feature.device
    chunks = []
    with torch.no_grad():
        for start in range(0, len(positions), chunk_size):
            part = np.ascontiguousarray(positions[start:start + chunk_size], dtype=np.float32)
            pts = torch.from_numpy(part).unsqueeze(0).to(device)
            feats = model.decoder(triplane_feature, pts)
            chunks.append(model.rgbMlp(feats).reshape(-1, 3).float().cpu().numpy())
    if not chunks:
        return np.zeros((0, 3), dtype=np.float32)
    colors = np.concatenate(chunks, axis=0)
    # Same mapping as before: rgbMlp output presumed in [-1, 1] -> [0, 1].
    return np.clip(colors * 0.5 + 0.5, 0.0, 1.0).astype(np.float32)


def _pad_texture(texture, mask, iterations=8):
    """Grow texel colours outward from covered texels to hide UV-seam bleeding.

    Each iteration fills every uncovered texel that has at least one covered
    neighbour (3x3 window) with the mean of those neighbours.

    Args:
        texture: (H, W, 3) float array; only texels where ``mask`` is True are read.
        mask: (H, W) bool array of covered texels.
        iterations: Number of one-texel growth steps.

    Returns:
        (H, W, 3) float32 texture; texels still unreached stay 0.
    """
    filled = np.asarray(mask, dtype=bool).copy()
    tex = np.where(filled[..., None], texture, 0.0).astype(np.float32)
    height, width = filled.shape
    for _ in range(iterations):
        if filled.all():
            break
        # Invariant: tex is 0 wherever filled is False, so plain sums work.
        tex_p = np.pad(tex, ((1, 1), (1, 1), (0, 0)))
        cnt_p = np.pad(filled.astype(np.float32), 1)
        acc = np.zeros_like(tex)
        cnt = np.zeros((height, width), dtype=np.float32)
        for dy in range(3):
            for dx in range(3):
                acc += tex_p[dy:dy + height, dx:dx + width]
                cnt += cnt_p[dy:dy + height, dx:dx + width]
        grow = ~filled & (cnt > 0)
        if not grow.any():
            break
        tex[grow] = acc[grow] / cnt[grow][:, None]
        filled |= grow
    return tex


def generate3d(model, rgb, ccm, device, texture_size=1024):
    """Reconstruct a textured mesh from tiled multi-view images and export it as .glb.

    Args:
        model: Reconstruction model. Uses ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising``, ``scheduler``,
            ``unet2``, ``decode``, ``decoder`` and ``rgbMlp``.
        rgb: (H, W, C) array of views tiled horizontally in 256-px-wide slices
            (six slices are read); values are divided by 255.
        ccm: Array with the same tiling as ``rgb``; its channel order is
            reversed before use (presumably BGR -> RGB; TODO confirm with caller).
        device: Torch device for the network passes.
        texture_size: Side length in texels of the baked square albedo texture.

    Returns:
        Path of a temporary ``.glb`` file. The caller owns it: it is not
        deleted automatically.
    """
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(color):
        # Cut the 256-px-wide views out of the tiled image, in order 5, 0, 1, 2, 3, 4.
        return torch.stack([color[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)

    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color = get_imgs(color)
    xyz = get_imgs(xyz)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
    if model.denoising:
        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color.to(device),
    }
    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)

    mesh_v, mesh_f = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    mesh_v = np.ascontiguousarray(mesh_v, dtype=np.float32)
    mesh_f = np.ascontiguousarray(mesh_f, dtype=np.int32)

    # CPU-only UV unwrapping. xatlas returns (vmapping, indices, uvs):
    # vmapping[j] is the original vertex behind UV vertex j.
    vmapping, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
    # xatlas yields uint32 indices, which older torch cannot wrap via from_numpy.
    ft = np.ascontiguousarray(ft, dtype=np.int32)
    vt = np.ascontiguousarray(vt, dtype=np.float32)

    # Bake the albedo in UV space. Each covered texel is mapped back to its 3D
    # surface point and coloured by the model's decoder + rgbMlp.
    face_ids = _rasterize_uv_face_ids(vt, ft, texture_size, texture_size)
    texel_pos, rows, cols = _interpolate_texel_positions(face_ids, vt, ft, mesh_v[vmapping])
    colors = _query_texel_colors(model, triplane_feature2, texel_pos)
    texture = np.zeros((texture_size, texture_size, 3), dtype=np.float32)
    texture[rows, cols] = colors
    texture = _pad_texture(texture, face_ids >= 0)

    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f).int(),
        vt=torch.from_numpy(vt).float(),
        ft=torch.from_numpy(ft).int(),
        albedo=torch.from_numpy(texture).float(),
    )
    # Close the handle immediately; only the path is needed (delete=False keeps the file).
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path