Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Refactor generate3d function in inference.py to implement CCM-based UV assignment for improved texture mapping. Updated color map preparation and vertex projection to enhance mesh quality in exports, ensuring compatibility with various input formats.
e666402
raw
history blame
3.81 kB
import numpy as np
import torch
import time
import tempfile
import zipfile
import nvdiffrast.torch as dr
import xatlas
import cv2
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
from kiui.mesh_utils import clean_mesh
def _reorder_views(strip, view_size=256):
    """Cut a (C, H, W) image strip into fixed-width views and stack them.

    The strip is sliced along its last axis into ``view_size``-wide tiles,
    which are stacked on a new leading axis in the order 5, 0, 1, 2, 3, 4.
    """
    order = (5, 0, 1, 2, 3, 4)
    return torch.stack(
        [strip[:, :, view_size * i:view_size * (i + 1)] for i in order],
        dim=0,
    )


def _to_uint8_hwc(image):
    """Return ``image`` as a uint8 array with the channel axis last.

    Arrays whose maximum is <= 1 are treated as normalized and scaled by 255
    before the cast. If the last axis does not have length 3 the array is
    transposed with axes (1, 2, 0), i.e. it is taken to be channel-first.
    """
    scaled = image if image.max() > 1 else image * 255
    as_uint8 = scaled.astype(np.uint8)
    if as_uint8.shape[-1] != 3:
        as_uint8 = np.transpose(as_uint8, (1, 2, 0))
    return as_uint8


def _ccm_vertex_uvs(verts, ccm_img):
    """Look up a (u, v) pair for every vertex from the CCM image.

    Each vertex's x/y is mapped from [-1, 1] to pixel coordinates of
    ``ccm_img`` (clamped to the image bounds); the first two channels of the
    sampled pixel, divided by 255, become that vertex's UV.

    Args:
        verts: ``[N, 3]`` array of vertex positions.
        ccm_img: ``[H, W, 3]`` uint8 image.

    Returns:
        ``[N, 2]`` float32 array (``[0, 2]`` for an empty mesh).
    """
    height, width, _ = ccm_img.shape
    # Assume mesh is in [-1, 1] in x/y. Clamping in float before the
    # truncating cast gives the same pixel as "truncate, then clamp" for every
    # finite input, without ever casting an out-of-range float to an integer.
    cols = np.clip((verts[:, 0] + 1) / 2 * (width - 1), 0, width - 1).astype(np.intp)
    rows = np.clip((verts[:, 1] + 1) / 2 * (height - 1), 0, height - 1).astype(np.intp)
    texels = ccm_img[rows, cols]  # [N, 3]
    return (texels[:, :2] / 255.0).astype(np.float32)


def generate3d(model, rgb, ccm, device):
    """Decode a mesh from multi-view colour/CCM images and export it as .glb.

    Args:
        model: Project model object. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``,
            ``scheduler`` and ``unet2``, calls ``eval()`` and ``decode()``, and
            overwrites ``model.renderer``.
        rgb: NumPy image array, channel-last, holding the colour views side
            by side along the width (six 256-px-wide views are sliced out).
            Presumably 0-255 valued, since it is divided by 255 -- confirm
            against the caller.
        ccm: NumPy image array laid out like ``rgb``; its channels are
            reversed before being fed to the network.
        device: Torch device the network inputs are moved to.

    Returns:
        Path (str) of a temporary ``.glb`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # HWC arrays -> CHW tensors scaled by 1/255; CCM channel order is flipped.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    triplane_color = _reorder_views(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    # A fresh stack is built for get_tri so it never aliases triplane_color.
    color = get_tri(_reorder_views(color), dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(_reorder_views(xyz), dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
    if model.denoising:
        # randint(20, 21) always yields timestep 20 (upper bound is exclusive).
        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color,
    }
    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)

    verts, faces = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    mesh_v = np.ascontiguousarray(verts)  # [N, 3]
    mesh_f = np.ascontiguousarray(faces)  # [M, 3]

    # --- CCM-based UV assignment ---
    ccm_img = _to_uint8_hwc(ccm)
    color_map = _to_uint8_hwc(rgb)
    albedo = cv2.cvtColor(color_map, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # NOTE(review): vertices are projected onto the whole CCM image using
    # only their x/y coordinates -- confirm this matches the intended layout.
    vt = _ccm_vertex_uvs(mesh_v, ccm_img)
    # Copy so the UV face indices never share storage with the geometry faces.
    ft = mesh_f.copy()

    # Create Mesh and export .glb
    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f).int(),
        vt=torch.from_numpy(vt).float(),
        ft=torch.from_numpy(ft).int(),
        albedo=torch.from_numpy(albedo).float(),
    )

    # Close the handle right away: only the reserved path is needed, and an
    # open handle would leak (and can block re-opening the file on Windows).
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path