Refactor generate3d function in inference.py to enhance mesh processing and export logic. Streamlined image tensor handling, improved denoising integration, and updated mesh export to include UV mapping. Ensured proper handling of temporary files for OBJ and GLB formats, enhancing overall readability and maintainability.
a14c9ce
from kiui.mesh_utils import clean_mesh | |
import trimesh | |
import zipfile | |
import tempfile | |
import os | |
import nvdiffrast.torch as dr | |
def generate3d(model, rgb, ccm, device):
    """Reconstruct a UV-textured mesh from tiled colour/CCM views and export it as GLB.

    Pipeline:
      1. Split the horizontally tiled ``rgb`` and ``ccm`` images into six
         256-px-wide views and pack them into a triplane via ``get_tri``.
      2. Run ``model.unet2`` on the triplane. When ``model.denoising`` is set,
         the input is first noised at timestep 20 with ``model.scheduler``.
      3. Decode a mesh with ``model.decode``, then clean/remesh it with
         ``kiui.mesh_utils.clean_mesh``.
      4. Write a UV-textured OBJ with ``model.export_mesh_wt_uv``, then
         convert that OBJ to GLB with trimesh.

    Args:
        model: Reconstruction model. Reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``,
            ``scheduler`` (only when denoising), ``unet2``, ``decode`` and
            ``export_mesh_wt_uv``. Side effects: overwrites
            ``model.renderer`` and calls ``model.eval()``.
        rgb: Colour image array laid out (H, W, C), scaled here by 1/255
            (presumably uint8 in [0, 255] — TODO confirm). The width must hold
            six 256-px tiles.
        ccm: Coordinate-map image array with the same layout as ``rgb``. Its
            last axis is reordered to channels (2, 1, 0) before use.
        device: Torch device used for inference and mesh export. Must be a
            CUDA device, because rasterisation uses an nvdiffrast CUDA context.

    Returns:
        str: Path to the exported ``.glb`` file. The intermediate ``.obj`` has
        the same base path with a ``.obj`` suffix. All files are written into a
        fresh per-call temporary directory and nothing is deleted, so the
        caller owns cleanup.
    """
    # NOTE(review): a new renderer replaces any existing one on every call.
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # (H, W, C) arrays scaled by 1/255 -> (C, H, W) float tensors.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    def get_imgs(strip):
        """Split a (C, H, W) strip into a stack of six (C, H, 256) tiles.

        Tiles are taken in the order [5, 0, 1, 2, 3, 4], i.e. the last tile
        is moved to the front. Assumes W >= 6 * 256.
        """
        tile_order = (5, 0, 1, 2, 3, 4)
        return torch.stack(
            [strip[:, :, 256 * i:256 * (i + 1)] for i in tile_order], dim=0
        )

    # (1, 6, H, 256, C) colour views, passed to the decoder as "triview_color".
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_views = get_imgs(color)
    xyz_views = get_imgs(xyz)
    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    # One no_grad scope for the whole pipeline: nothing here needs autograd.
    with torch.no_grad():
        if model.denoising:
            # randint(20, 21) always yields 20. It is kept (rather than
            # torch.full) so the RNG stream, and thus the noise below, is
            # unchanged for a given seed.
            tnew = torch.randint(
                20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device
            )
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

        verts, faces = clean_mesh(
            data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
            data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
            repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
        )
        # Put the cleaned mesh back on the *requested* device. A bare .cuda()
        # would target the default GPU and mismatch triplane_feature2 when
        # device is e.g. "cuda:1".
        data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
        data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

        # Fresh directory per call, with a fixed base name inside it.
        # Unlike NamedTemporaryFile(delete=False), this leaves no stray empty
        # placeholder file, and it keeps the exporter's companion outputs
        # (presumably .mtl / texture PNG) together with the OBJ/GLB.
        obj_base = os.path.join(tempfile.mkdtemp(), "mesh")

        # Rasterise on the same GPU as the mesh tensors.
        glctx = dr.RasterizeCudaContext(device=device)
        # Empty index: the exporter is expected to write obj_base + ".obj"
        # (the path loaded below).
        model.export_mesh_wt_uv(
            glctx, data_config, obj_base, "", device,
            res=(1024, 1024), tri_fea_2=triplane_feature2,
        )

    # Load the textured OBJ unmodified and re-export it as GLB (format inferred
    # from the extension).
    mesh = trimesh.load(obj_base + ".obj", process=False)
    glb_path = obj_base + ".glb"
    mesh.export(glb_path)
    return glb_path