Refactor the generate3d function in inference.py to perform CPU-only UV unwrapping with xatlas and trimesh for improved texture mapping. Update the texture-baking logic to use the model's decoder and rgbMlp, improving the quality of textures in exported meshes.
f6ac519
import numpy as np | |
import torch | |
import time | |
import tempfile | |
import zipfile | |
import nvdiffrast.torch as dr | |
import xatlas | |
import cv2 | |
import trimesh | |
from util.utils import get_tri | |
from mesh import Mesh | |
from util.renderer import Renderer | |
from kiui.mesh_utils import clean_mesh | |
def _rasterize_uv_atlas(v, f, vt, ft, res, face_chunk=65536):
    """Locate every texel covered by the UV layout and its 3D surface position.

    A pure-NumPy stand-in for rasterising the UV triangles on the GPU. A texel
    counts as covered when its centre lies inside (or on an edge of) a face's
    UV triangle. Its position is the barycentric interpolation of that face's
    3D corners.

    Args:
        v: (V, 3) vertex positions.
        f: (F, 3) vertex indices per face.
        vt: (T, 2) UV coordinates, expected in [0, 1].
        ft: (F, 3) UV indices per face; row i must describe the same face as ``f[i]``.
        res: (height, width) of the texture in texels.
        face_chunk: faces processed per vectorised batch, to bound peak memory.

    Returns:
        ``(rows, cols, positions)``:
            rows, cols: int64 texel coordinates, with ``row = v * height`` and
                ``col = u * width``.
            positions: (K, 3) float32 surface points.
        Each covered texel appears exactly once, in row-major order.
    """
    height, width = res
    v = np.asarray(v, dtype=np.float64)
    f = np.asarray(f, dtype=np.int64)
    vt = np.asarray(vt, dtype=np.float64)
    ft = np.asarray(ft, dtype=np.int64)
    scale = np.array([width, height], dtype=np.float64)
    upper = np.array([width - 1, height - 1], dtype=np.int64)
    eps = 1e-7  # tolerance so texel centres lying exactly on an edge are kept

    rows_out, cols_out, pos_out = [], [], []
    for start in range(0, len(ft), face_chunk):
        stop = start + face_chunk
        tri_uv = vt[ft[start:stop]] * scale  # (B, 3, 2) in texel units
        tri_xyz = v[f[start:stop]]  # (B, 3, 3)

        # Texel (x, y) has its centre at (x + 0.5, y + 0.5); clamp each face's bbox to the image.
        lo = np.maximum(np.ceil(tri_uv.min(axis=1) - 0.5), 0).astype(np.int64)
        hi = np.minimum(np.floor(tri_uv.max(axis=1) - 0.5), upper).astype(np.int64)
        span = np.maximum(hi - lo + 1, 0)  # (B, 2) bbox size in texels
        counts = span[:, 0] * span[:, 1]
        total = int(counts.sum())
        if total == 0:
            continue

        # Enumerate every candidate texel of every bbox without a Python loop over faces.
        face = np.repeat(np.arange(len(counts)), counts)
        local = np.arange(total) - np.repeat(np.cumsum(counts) - counts, counts)
        bbox_w = span[face, 0]
        cols = lo[face, 0] + local % bbox_w
        rows = lo[face, 1] + local // bbox_w

        # Barycentric coordinates of the texel centre (Cramer's rule); degenerate UV faces are dropped.
        a = tri_uv[face, 0]
        e0 = tri_uv[face, 1] - a
        e1 = tri_uv[face, 2] - a
        dx = cols + 0.5 - a[:, 0]
        dy = rows + 0.5 - a[:, 1]
        det = e0[:, 0] * e1[:, 1] - e1[:, 0] * e0[:, 1]
        safe_det = np.where(det == 0, 1.0, det)
        w1 = (dx * e1[:, 1] - e1[:, 0] * dy) / safe_det
        w2 = (e0[:, 0] * dy - dx * e0[:, 1]) / safe_det
        w0 = 1.0 - w1 - w2
        inside = (det != 0) & (w0 >= -eps) & (w1 >= -eps) & (w2 >= -eps)
        if not inside.any():
            continue

        hit = face[inside]
        pos = (
            w0[inside][:, None] * tri_xyz[hit, 0]
            + w1[inside][:, None] * tri_xyz[hit, 1]
            + w2[inside][:, None] * tri_xyz[hit, 2]
        )
        rows_out.append(rows[inside])
        cols_out.append(cols[inside])
        pos_out.append(pos)

    if not rows_out:
        empty = np.zeros(0, dtype=np.int64)
        return empty, empty.copy(), np.zeros((0, 3), dtype=np.float32)

    rows = np.concatenate(rows_out)
    cols = np.concatenate(cols_out)
    positions = np.concatenate(pos_out)
    # Texels on shared edges are claimed by several faces; keep one so each is decoded once.
    _, first = np.unique(rows * width + cols, return_index=True)
    return rows[first], cols[first], positions[first].astype(np.float32)


def _dilate_texture(texture, filled, iterations):
    """Grow chart colours outward into empty texels to hide UV seams.

    Each iteration gives every empty texel that touches a filled one (8-neighbourhood)
    the mean colour of its filled neighbours. Texels that were already filled are never
    modified. Returns a new array; the inputs are not mutated.
    """
    texture = texture.copy()
    filled = filled.copy()
    height, width = filled.shape
    for _ in range(iterations):
        colour_pad = np.pad(texture * filled[..., None], ((1, 1), (1, 1), (0, 0)))
        mask_pad = np.pad(filled.astype(np.float32), 1)
        colour_sum = np.zeros_like(texture)
        neighbours = np.zeros((height, width), dtype=np.float32)
        for oy in range(3):
            for ox in range(3):
                colour_sum += colour_pad[oy:oy + height, ox:ox + width]
                neighbours += mask_pad[oy:oy + height, ox:ox + width]
        ring = ~filled & (neighbours > 0)
        if not ring.any():
            break
        texture[ring] = colour_sum[ring] / neighbours[ring][:, None]
        filled |= ring
    return texture


def generate3d(model, rgb, ccm, device, *, tex_res=(1024, 1024), pad_px=4):
    """Reconstruct a textured mesh from tiled multi-view images and export it as ``.glb``.

    Pipeline:
        1. Build triplane inputs from the RGB and CCM tiles.
        2. Run ``model.unet2`` (optionally after noising via ``model.scheduler``).
        3. Decode a mesh and remesh it with kiui's ``clean_mesh``.
        4. UV-unwrap it with xatlas.
        5. Bake an albedo texture by querying ``model.decoder`` / ``model.rgbMlp`` at
           the surface point behind every covered texel.

    Args:
        model: Reconstruction model. Its renderer settings, ``unet2``, ``decode``,
            ``decoder``, ``rgbMlp`` and (when ``model.denoising``) ``scheduler`` are used.
            ``model.renderer`` is replaced and ``model.eval()`` is called.
        rgb: NumPy image holding views as 256-px-wide tiles side by side. Values are
            presumably in [0, 255]; they are divided by 255 here. TODO confirm the
            expected shape with the caller.
        ccm: NumPy image with the same tiling; its channel order is reversed before use.
        device: Torch device for the network forward passes.
        tex_res: (height, width) of the baked albedo texture.
        pad_px: Number of texels chart colours are grown outward to hide seams.

    Returns:
        Path of a temporary ``.glb`` file created with ``delete=False``; the caller is
        responsible for removing it.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    def get_imgs(images):
        # Split the (C, H, W) strip into six 256-px-wide tiles, reordered so tile 5 comes first.
        return torch.stack([images[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)

    color = get_imgs((torch.from_numpy(rgb) / 255).permute(2, 0, 1))
    xyz = get_imgs((torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1))
    triplane_color = color.permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
    with torch.no_grad():
        if model.denoising:
            # Fixed timestep 20 (randint's upper bound is exclusive).
            tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)

    # Remesh on the CPU, then pin the dtypes that xatlas and Mesh consume below.
    verts, faces = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    mesh_v = np.ascontiguousarray(verts, dtype=np.float32)
    mesh_f = np.ascontiguousarray(faces, dtype=np.int32)

    # CPU UV unwrapping. `ft` indexes `vt` face-by-face, assumed in the same face order as
    # `mesh_f`. xatlas returns uint32 indices, which torch.from_numpy rejects on older releases.
    _, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
    ft = ft.astype(np.int32)
    vt = vt.astype(np.float32)

    # Bake: decode a colour for the surface point behind every texel the UV layout covers.
    # Row = v * height matches rasterising `vt * 2 - 1` with nvdiffrast (kiui-style Mesh convention).
    tex_h, tex_w = tex_res
    rows, cols, surface_pts = _rasterize_uv_atlas(mesh_v, mesh_f, vt, ft, (tex_h, tex_w))
    texture = np.zeros((tex_h, tex_w, 3), dtype=np.float32)
    filled = np.zeros((tex_h, tex_w), dtype=bool)
    if len(surface_pts):
        # Queries must live on the same device as the triplane features.
        query = torch.from_numpy(surface_pts).unsqueeze(0).to(triplane_feature2.device)
        with torch.no_grad():
            dec_verts = model.decoder(triplane_feature2, query)
            colors = model.rgbMlp(dec_verts).reshape(-1, 3).float().cpu().numpy()
        # rgbMlp output is presumably in [-1, 1]; map it to [0, 1].
        texture[rows, cols] = (colors * 0.5 + 0.5).clip(0, 1)
        filled[rows, cols] = True
    texture = _dilate_texture(texture, filled, pad_px)

    mesh = Mesh(
        v=torch.from_numpy(mesh_v),
        f=torch.from_numpy(mesh_f),
        vt=torch.from_numpy(vt),
        ft=torch.from_numpy(ft),
        albedo=torch.from_numpy(texture),
    )
    # NOTE(review): kiui-style Mesh glb writers may require vertex normals; compute them when
    # this Mesh exposes auto_normal() and none were supplied. Confirm against mesh.py.
    if getattr(mesh, "vn", None) is None and hasattr(mesh, "auto_normal"):
        mesh.auto_normal()

    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path