Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Ensure tensors are on CPU in generate3d function of inference.py for consistent processing. Updated points_tensor and triplane_feature2 to prevent device-related issues during inference.
dcb102c
raw
history blame
4.27 kB
import numpy as np
import torch
import time
import tempfile
import zipfile
import nvdiffrast.torch as dr
import xatlas
import cv2
import trimesh
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
from kiui.mesh_utils import clean_mesh
def _split_views(img):
    """Cut a horizontal strip of six 256-px-wide views into a stacked batch.

    Args:
        img: tensor sliced as ``img[:, :, x]``; its last axis must hold at least
            ``6 * 256`` columns (presumably ``(C, H, 6 * 256)`` — TODO confirm).

    Returns:
        Tensor ``(6, C, H, 256)`` with the views taken in the order
        ``[5, 0, 1, 2, 3, 4]``.
    """
    return torch.stack(
        [img[:, :, 256 * i:256 * (i + 1)] for i in (5, 0, 1, 2, 3, 4)], dim=0
    )


def _barycentric(points, triangles, eps=1e-12):
    """Return barycentric weights of each point with respect to its own triangle.

    Args:
        points: ``(N, 3)`` positions lying on (or numerically near) the triangles.
        triangles: ``(N, 3, 3)``; ``triangles[i]`` holds the three corners of
            the triangle that ``points[i]`` belongs to.
        eps: relative threshold (on sin^2 of the corner angle) below which a
            triangle is treated as degenerate.

    Returns:
        ``(N, 3)`` float64 weights for corners 0, 1, 2. Degenerate (zero-area)
        triangles get uniform weights of 1/3 instead of NaN/inf.
    """
    points = np.asarray(points, dtype=np.float64).reshape(-1, 3)
    triangles = np.asarray(triangles, dtype=np.float64).reshape(-1, 3, 3)
    origin = triangles[:, 0]
    edge0 = triangles[:, 1] - origin
    edge1 = triangles[:, 2] - origin
    rel = points - origin
    d00 = np.einsum("ij,ij->i", edge0, edge0)
    d01 = np.einsum("ij,ij->i", edge0, edge1)
    d11 = np.einsum("ij,ij->i", edge1, edge1)
    d20 = np.einsum("ij,ij->i", rel, edge0)
    d21 = np.einsum("ij,ij->i", rel, edge1)
    denom = d00 * d11 - d01 * d01
    # Scale-invariant degeneracy test: denom / (d00 * d11) == sin^2(angle at corner 0).
    degenerate = denom <= eps * d00 * d11
    safe = np.where(degenerate, 1.0, denom)
    w1 = (d11 * d20 - d01 * d21) / safe
    w2 = (d00 * d21 - d01 * d20) / safe
    bary = np.stack([1.0 - w1 - w2, w1, w2], axis=-1)
    bary[degenerate] = 1.0 / 3.0
    return bary


def _splat_colors(uv, colors, tex_res):
    """Average per-sample colours into the texels addressed by their UVs.

    Texel addressing: column = round(u * (width - 1)), row = round(v * (height - 1)),
    i.e. row 0 corresponds to v = 0. NOTE(review): this assumes ``mesh.Mesh``
    samples ``albedo`` with the same convention — confirm against mesh.py.

    Args:
        uv: ``(N, 2)`` UV coordinates, nominally in ``[0, 1]`` (clamped).
        colors: ``(N, 3)`` RGB values.
        tex_res: ``(width, height)`` of the texture.

    Returns:
        ``(texture, filled)``: a ``(height, width, 3)`` float32 image holding the
        mean colour of every texel hit by at least one sample (zeros elsewhere),
        and a ``(height, width)`` bool mask of the texels that were hit.
    """
    width, height = int(tex_res[0]), int(tex_res[1])
    uv = np.asarray(uv, dtype=np.float64).reshape(-1, 2)
    colors = np.asarray(colors, dtype=np.float64).reshape(-1, 3)
    cols = np.clip(np.rint(uv[:, 0] * (width - 1)), 0, width - 1).astype(np.int64)
    rows = np.clip(np.rint(uv[:, 1] * (height - 1)), 0, height - 1).astype(np.int64)
    flat = rows * width + cols
    n_texels = width * height
    counts = np.bincount(flat, minlength=n_texels)
    sums = np.stack(
        [np.bincount(flat, weights=colors[:, ch], minlength=n_texels) for ch in range(3)],
        axis=-1,
    )
    filled = counts > 0
    texture = np.zeros((n_texels, 3), dtype=np.float32)
    texture[filled] = sums[filled] / counts[filled, None]
    return texture.reshape(height, width, 3), filled.reshape(height, width)


def _fill_texture_holes(texture, filled, radius=3):
    """Inpaint texels that received no sample, keeping sampled texels unchanged.

    This closes small gaps inside UV charts and bleeds colour past chart borders,
    which hides seams when the texture is filtered. Uses OpenCV Telea inpainting
    on an 8-bit copy. If every texel (or no texel) is filled, ``texture`` is
    returned as-is.
    """
    if filled.all() or not filled.any():
        return texture
    tex_u8 = np.clip(np.rint(texture * 255.0), 0, 255).astype(np.uint8)
    holes = np.where(filled, 0, 255).astype(np.uint8)
    inpainted = cv2.inpaint(tex_u8, holes, radius, cv2.INPAINT_TELEA)
    result = texture.copy()
    result[~filled] = inpainted[~filled].astype(np.float32) / 255.0
    return result


def _bake_texture(model, triplane_feature, mesh_v, mesh_f, ft, vt, tex_res):
    """Bake model colours into a ``(tex_res[1], tex_res[0], 3)`` albedo in [0, 1].

    Samples ``tex_res[0] * tex_res[1]`` random surface points, decodes their colour
    with ``model.decoder`` + ``model.rgbMlp``, maps each point to UV space through
    its face's barycentric weights and ``ft``/``vt``, splats, then inpaints gaps.
    Sampling is random, so the texture is not bit-for-bit reproducible.
    """
    surface = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
    points, face_indices = trimesh.sample.sample_surface(surface, tex_res[0] * tex_res[1])
    # Reuse each sample's barycentric weights on its face's UV corners to get its UV.
    bary = _barycentric(points, mesh_v[mesh_f[face_indices]])
    uv = np.einsum("nk,nkc->nc", bary, vt[ft[face_indices]])
    # Keep the query points on the same device as the triplane features (and so,
    # presumably, the model) instead of forcing CPU, which breaks GPU inference.
    points_tensor = torch.from_numpy(np.asarray(points, dtype=np.float32))
    points_tensor = points_tensor.unsqueeze(0).to(triplane_feature.device)  # (1, N, 3)
    with torch.no_grad():
        dec_verts = model.decoder(triplane_feature, points_tensor)
        colors = model.rgbMlp(dec_verts)
    # reshape (not squeeze) so a single sample still yields an (N, 3) array.
    colors = colors.reshape(-1, 3).cpu().numpy()
    colors = (colors * 0.5 + 0.5).clip(0, 1)
    texture, filled = _splat_colors(uv, colors, tex_res)
    texture = _fill_texture_holes(texture, filled)
    return np.clip(texture, 0, 1)


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from six-view colour/CCM strips and export a .glb.

    Args:
        model: CRM-style model. Reads ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising``, ``scheduler``, and calls
            ``unet2``, ``decode``, ``decoder`` and ``rgbMlp``.
        rgb: numpy image whose width holds six 256-px views side by side; values
            are divided by 255 (presumably uint8 ``(H, 6*256, 3)`` — TODO confirm).
        ccm: numpy image with the same layout; its channels are reversed
            ``(2, 1, 0)`` before use (presumably BGR -> RGB — confirm).
        device: torch device the triplane inputs are moved to.

    Returns:
        Path of a temporary ``.glb`` file created with ``delete=False``; the
        caller is responsible for removing it.

    Side effects:
        Replaces ``model.renderer`` and switches ``model`` to eval mode.
    """
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    color_views = _split_views(torch.from_numpy(rgb).permute(2, 0, 1) / 255)
    xyz_views = _split_views(torch.from_numpy(ccm[:, :, (2, 1, 0)]).permute(2, 0, 1) / 255)

    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    with torch.no_grad():
        if model.denoising:
            # randint(20, 21) always yields timestep 20.
            tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)

    mesh_v, mesh_f = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    mesh_v = np.ascontiguousarray(mesh_v)
    mesh_f = np.ascontiguousarray(mesh_f)

    # CPU-only UV unwrapping. ft indexes vt (UV corners) separately from mesh_f.
    _vmapping, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
    # xatlas returns uint32 indices; torch.from_numpy rejects uint32 on older torch.
    ft = ft.astype(np.int64)

    texture = _bake_texture(model, triplane_feature2, mesh_v, mesh_f, ft, vt, tex_res=(1024, 1024))

    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f).int(),
        vt=torch.from_numpy(vt).float(),
        ft=torch.from_numpy(ft).int(),
        albedo=torch.from_numpy(texture).float(),
    )
    # Close the handle before writing: the path survives (delete=False), and the
    # write does not fail on platforms that lock open files.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path