Spaces:
mashroo
/
Running on Zero

CRM / inference.py
YoussefAnso's picture
Refactor texture baking logic in generate3d function of inference.py to utilize per-vertex colors for improved texture quality. Updated the method for filling the texture image to enhance the visual output in mesh exports.
37c1f6f
raw
history blame
3.52 kB
import numpy as np
import torch
import time
import tempfile
import zipfile
import nvdiffrast.torch as dr
import xatlas
import cv2
import trimesh
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
from kiui.mesh_utils import clean_mesh
def _split_views(strip):
    """Stack six 256-px-wide tiles of a (C, H, W) strip into a (6, C, H, 256) tensor.

    The last tile is placed first, i.e. the tile order is [5, 0, 1, 2, 3, 4].
    Expects the last axis to hold (at least) six 256-px tiles.
    """
    return torch.stack([strip[:, :, 256 * i:256 * (i + 1)] for i in (5, 0, 1, 2, 3, 4)], dim=0)


def _build_triplane(rgb, ccm, device):
    """Convert the tiled RGB and CCM images into the U-Net input.

    Args:
        rgb: (H, W, 3) array of six horizontally tiled views; divided by 255
            here, so presumably 0-255 values (confirm with caller).
        ccm: same layout as ``rgb``; its channels are reversed with (2, 1, 0)
            before use (presumably BGR -> RGB; confirm with caller).
        device: torch device for the returned tensors.

    Returns:
        (triplane, triplane_color): ``triplane`` is the channel-wise concat of
        the color and xyz triplanes built by ``get_tri``; ``triplane_color`` is
        the RGB views permuted to (1, 6, H, 256, C). Both live on ``device``.
    """
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)
    triplane_color = _split_views(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color_tri = get_tri(_split_views(color), dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(_split_views(xyz), dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)
    return triplane, triplane_color


def _predict_triplane_features(model, triplane):
    """Put ``model`` in eval mode and run ``model.unet2`` on ``triplane`` without gradients.

    When ``model.denoising`` is truthy, the input is first noised at timestep 20
    via ``model.scheduler.add_noise`` and the timestep tensor is passed to the U-Net.
    """
    model.eval()
    if not model.denoising:
        with torch.no_grad():
            return model.unet2(triplane)
    # randint(20, 21) always yields 20; kept as randint (rather than a constant)
    # in case it advances the RNG state ahead of randn_like below.
    tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
    noise_new = torch.randn_like(triplane) * 0.5 + 0.5
    triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
    with torch.no_grad():
        return model.unet2(triplane, tnew)


def _extract_clean_mesh(model, triplane_feature2, triplane_color, device):
    """Decode a mesh from the triplane features and clean/remesh it with ``clean_mesh``.

    Returns:
        (verts, faces) numpy arrays exactly as returned by ``clean_mesh``.
    """
    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color.to(device),
    }
    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)
    return clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )


def _bake_vertex_colors(vt, ft, faces, vertex_colors, tex_res=(1024, 1024), pad=4):
    """Rasterize per-face colors into a UV texture.

    Each UV triangle ``ft[i]`` is filled with the mean color of the 3D face
    ``faces[i]`` (a flat per-face color, not barycentric interpolation).
    Texels not covered by any UV chart are then grown into by ``pad``
    iterations of a 3x3 dilation, so texture filtering at chart borders
    does not pull in black background.

    Args:
        vt: (N, 2) UV coordinates in [0, 1].
        ft: (F, 3) UV-vertex indices per face.
        faces: (F, 3) position-vertex indices, in the same face order as ``ft``.
        vertex_colors: (V, 3) RGB colors, one per position vertex.
        tex_res: (width, height) of the texture in pixels.
        pad: number of gutter-dilation iterations (0 disables padding).

    Returns:
        Contiguous float32 (height, width, 3) texture clipped to [0, 1].
    """
    width, height = tex_res
    # The 4th channel records coverage so the gutter fill knows which texels are empty.
    rgba = np.zeros((height, width, 4), dtype=np.float32)
    # Map UV [0, 1] onto pixel indices [0, size - 1]; scaling by `size` would
    # put u == 1 one column past the image edge.
    uv_px = np.rint(vt * np.array([width - 1, height - 1], dtype=np.float32)).astype(np.int32)
    face_colors = vertex_colors[faces].mean(axis=1)
    for uv_idx, (r, g, b) in zip(ft, face_colors):
        cv2.fillPoly(rgba, [uv_px[uv_idx]], (float(r), float(g), float(b), 1.0))
    kernel = np.ones((3, 3), dtype=np.uint8)
    for _ in range(pad):
        empty = rgba[..., 3] == 0
        if not empty.any():
            break
        grown = cv2.dilate(rgba, kernel)
        rgba[empty] = grown[empty]
    return np.ascontiguousarray(np.clip(rgba[..., :3], 0, 1))


def generate3d(model, rgb, ccm, device, vertex_color_fn=None):
    """Reconstruct a textured mesh from tiled RGB / CCM views and export it as .glb.

    Side effects: replaces ``model.renderer`` with a fresh ``Renderer`` and puts
    ``model`` in eval mode.

    Args:
        model: reconstruction model (needs ``unet2``, ``decode``, ``denoising``,
            ``scheduler`` and the attributes read to build the ``Renderer``).
        rgb: (H, W, 3) array of six horizontally tiled 256-px views.
        ccm: coordinate map with the same tiling as ``rgb``.
        device: torch device used for the network passes.
        vertex_color_fn: optional callable taking the cleaned mesh's (V, 3)
            vertex positions and returning (V, 3) RGB colors in [0, 1]. When
            None, every vertex is white (the previous fixed behavior).

    Returns:
        Path of a ``.glb`` file created with ``delete=False``; the caller owns it.

    Raises:
        ValueError: if ``vertex_color_fn`` returns an array that is not (V, 3).
    """
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)
    triplane, triplane_color = _build_triplane(rgb, ccm, device)
    triplane_feature2 = _predict_triplane_features(model, triplane)
    verts, faces = _extract_clean_mesh(model, triplane_feature2, triplane_color, device)
    mesh_v = np.ascontiguousarray(verts)
    mesh_f = np.ascontiguousarray(faces, dtype=np.int32)

    # CPU-only UV unwrapping with xatlas. The vertex remapping it returns is not
    # needed: positions (v/f) and UVs (vt/ft) are kept as separate index sets.
    _, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
    # xatlas yields uint32 indices, which torch.from_numpy rejects on older torch.
    ft = ft.astype(np.int32)

    if vertex_color_fn is None:
        vertex_colors = np.ones((mesh_v.shape[0], 3), dtype=np.float32)  # fallback: white
    else:
        vertex_colors = np.asarray(vertex_color_fn(mesh_v), dtype=np.float32)
        if vertex_colors.shape != (mesh_v.shape[0], 3):
            raise ValueError(
                f"vertex_color_fn must return shape {(mesh_v.shape[0], 3)}, got {vertex_colors.shape}"
            )
    texture = _bake_vertex_colors(vt, ft, mesh_f, vertex_colors, tex_res=(1024, 1024))

    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f),
        vt=torch.from_numpy(vt).float(),
        ft=torch.from_numpy(ft),
        albedo=torch.from_numpy(texture).float(),
    )
    # NOTE(review): .glb export may require vertex normals, which Mesh.load
    # computes but the constructor may not. Compute them when the Mesh class
    # offers auto_normal; confirm against mesh.py.
    if getattr(mesh, "vn", None) is None and hasattr(mesh, "auto_normal"):
        mesh.auto_normal()
    # Close the handle right away; only the persistent path is needed.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path