Spaces:
mashroo
/
Running on Zero

CRM / inference.py
YoussefAnso's picture
Refactor generate3d function in inference.py to improve mesh export process. Updated OBJ and texture file checks, streamlined GLB conversion with embedded texture handling, and enhanced error handling for missing files. This change ensures proper export of textured GLB files.
30b6055
raw
history blame
3.37 kB
import numpy as np
import torch
import time
import tempfile
import os
from PIL import Image
import trimesh
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from six-view images and export it as GLB.

    Args:
        model: Reconstruction model. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``,
            ``denoising``, ``scheduler`` and ``unet2`` from it, calls
            ``decode`` and ``export_mesh_wt_uv``, switches it to eval mode,
            and overwrites ``model.renderer``.
        rgb: NumPy array of colour views, channels last, holding six
            256-pixel-wide tiles side by side along the width axis.
            Values are divided by 255, so presumably 0-255 -- confirm
            against the caller.
        ccm: NumPy array with the same tile layout as ``rgb``. Its first
            three channels are reversed before use (presumably a coordinate
            map stored in BGR order -- confirm against the caller).
        device: Torch device (string or ``torch.device``) on which the
            triplane and the cleaned mesh tensors are placed.

    Returns:
        str: Path of the exported ``.glb`` file.

    Raises:
        RuntimeError: If the exporter did not produce both the ``.obj`` and
            the ``.png`` file next to the temporary base path.
    """
    tile = 256  # width in pixels of one view inside the horizontal strip
    view_order = (5, 0, 1, 2, 3, 4)  # the last tile of the strip goes first

    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # HWC arrays -> CHW float tensors scaled by 1/255.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    def get_imgs(img_tensor):
        """Cut the strip into six tiles and stack them as [6, C, H, W]."""
        tiles = [
            img_tensor[:, :, tile * i:tile * (i + 1)] for i in view_order
        ]
        return torch.stack(tiles, dim=0)

    color = get_imgs(color)
    xyz = get_imgs(xyz)

    # Channels-last copy of the colour views, handed to the decoder below.
    triplane_color = color.permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()

    if model.denoising:
        # randint(20, 21) always yields timestep 20 for every batch entry.
        tnew = torch.randint(
            20, 21, [triplane.shape[0]],
            dtype=torch.long, device=triplane.device,
        )
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color,
    }

    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Imported lazily, as in the original, so kiui is only needed here.
    from kiui.mesh_utils import clean_mesh

    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    # Honour the caller's device instead of unconditionally using CUDA.
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    # Reserve a unique base path; the exporter appends its own extensions.
    # The handle is closed right away so no descriptor is left open.
    with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp_file:
        mesh_path_obj = tmp_file.name

    # Export to .obj with UV and texture image.
    # NOTE(review): this call runs with autograd enabled; wrapping it in
    # torch.no_grad() would probably save memory -- confirm it is safe.
    model.export_mesh_wt_uv(
        None,  # GL context skipped (no dr.RasterizeGLContext)
        data_config,
        mesh_path_obj,
        "",
        device,
        res=(1024, 1024),
        tri_fea_2=triplane_feature2
    )

    # Convert to GLB with embedded texture.
    obj_path = mesh_path_obj + ".obj"
    texture_path = mesh_path_obj + ".png"
    if not (os.path.exists(obj_path) and os.path.exists(texture_path)):
        raise RuntimeError("OBJ or texture file missing — cannot convert to GLB.")

    mesh_scene = trimesh.load(obj_path, force='scene')

    # Keep the texture as a PIL image (what TextureVisuals documents for
    # ``image``) and copy it so the file handle can be closed immediately.
    with Image.open(texture_path) as texture_file:
        texture_image = texture_file.copy()

    for geom in mesh_scene.geometry.values():
        # Carry over the UVs loaded from the OBJ; without them the texture
        # cannot be mapped onto the mesh in the exported GLB.
        uv = getattr(geom.visual, "uv", None)
        geom.visual = trimesh.visual.texture.TextureVisuals(
            uv=uv, image=texture_image
        )

    glb_path = mesh_path_obj + "_final.glb"
    mesh_scene.export(glb_path)
    return glb_path