Refactor the `generate3d` function in inference.py to implement CCM-based UV assignment for improved texture mapping. Update color-map preparation and vertex projection to improve the quality of exported meshes and to accept inputs in various formats.
e666402
import numpy as np | |
import torch | |
import time | |
import tempfile | |
import zipfile | |
import nvdiffrast.torch as dr | |
import xatlas | |
import cv2 | |
from util.utils import get_tri | |
from mesh import Mesh | |
from util.renderer import Renderer | |
from kiui.mesh_utils import clean_mesh | |
def _to_uint8_hwc(img):
    """Return *img* as a ``uint8`` array with channels on the last axis.

    Values are cast directly when ``img.max() > 1`` (treated as 0-255) and
    scaled by 255 otherwise (treated as 0-1). An array whose last axis is not
    3 is treated as channel-first ``(C, H, W)`` and transposed to ``(H, W, C)``.
    """
    out = img.astype(np.uint8) if img.max() > 1 else (img * 255).astype(np.uint8)
    if out.shape[-1] != 3:
        out = np.transpose(out, (1, 2, 0))
    return out


def _ccm_vertex_uvs(vertices, ccm_img):
    """Look up one UV pair per vertex from a CCM image (vectorized).

    Each vertex's x/y is mapped from [-1, 1] onto the full image (x -> column,
    y -> row). The result is truncated toward zero, like ``int()``, and
    clamped to the image bounds. Channels 0 and 1 of that pixel, divided by
    255, become (u, v).

    NOTE(review): the caller treats CCM pixel values as xyz coordinates, not
    texture coordinates. The projection also spans the whole multi-view strip
    rather than a single view, so texels may not match the surface they are
    mapped onto. A nearest-CCM-pixel lookup (vertex position -> closest CCM
    value -> that pixel's image coordinate) is the usual approach. Confirm
    the intent before relying on this mapping.

    Args:
        vertices: ``(N, 3)`` vertex positions, assumed to lie roughly in [-1, 1].
        ccm_img: ``(H, W, 3)`` uint8 CCM image.

    Returns:
        ``(N, 2)`` float32 array of UVs.
    """
    height, width = ccm_img.shape[:2]
    # astype(int64) truncates toward zero, matching the scalar int() semantics.
    cols = np.clip(((vertices[:, 0] + 1) / 2 * (width - 1)).astype(np.int64), 0, width - 1)
    rows = np.clip(((vertices[:, 1] + 1) / 2 * (height - 1)).astype(np.int64), 0, height - 1)
    # Divide in float64 first, then cast, so rounding matches a per-pixel `value / 255.0`.
    return (ccm_img[rows, cols, :2] / 255.0).astype(np.float32)


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from a color strip and its CCM strip; export it as .glb.

    Side effects: replaces ``model.renderer`` and puts the model in eval mode.

    Args:
        model: reconstruction model. Reads ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising``, ``scheduler`` (only
            when denoising), ``unet2`` and ``decode``.
        rgb: NumPy color strip indexed ``[H, W, C]``, expected in 0-255 (it is
            divided by 255). Six 256-px-wide views are laid out along W.
        ccm: NumPy canonical-coordinate-map strip with the same layout as ``rgb``.
        device: torch device the network runs on.

    Returns:
        Path to a temporary ``.glb`` file. It is created with ``delete=False``,
        so the caller owns it and is responsible for removing it.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # [C, H, W_strip] tensors in [0, 1]; CCM channels are reversed to form xyz.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    def get_imgs(strip):
        # Split the strip into six 256-px tiles, last tile first: [6, C, H, 256].
        return torch.stack([strip[:, :, 256 * i:256 * (i + 1)] for i in (5, 0, 1, 2, 3, 4)], dim=0)

    # Separate get_imgs() calls keep triplane_color from sharing storage with get_tri's input.
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)  # [1, 6, H, 256, C]
    color_tri = get_tri(get_imgs(color), dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(get_imgs(xyz), dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    with torch.no_grad():
        if model.denoising:
            # Fixed timestep 20; randint is kept so the RNG stream feeding randn_like is unchanged.
            tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }
        verts, faces = model.decode(data_config, triplane_feature2)

    mesh_v, mesh_f = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    mesh_v = np.ascontiguousarray(mesh_v)  # [N, 3]
    mesh_f = np.ascontiguousarray(mesh_f)  # [M, 3]

    # --- CCM-based UV assignment ---
    ccm_img = _to_uint8_hwc(ccm)
    # The network consumed `rgb` above without a channel swap, so the texture
    # uses the same channel order (no BGR->RGB conversion).
    # NOTE(review): assumes Mesh expects RGB albedo in [0, 1] -- confirm.
    albedo = _to_uint8_hwc(rgb).astype(np.float32) / 255.0
    # NOTE(review): raw CCM channels 0/1 are what this function reorders into z/y above.
    vt = _ccm_vertex_uvs(mesh_v, ccm_img)

    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f).int(),
        vt=torch.from_numpy(vt).float(),
        # Copy so `ft` never shares storage with `f` (.int() is a no-op on int32).
        ft=torch.from_numpy(mesh_f.copy()).int(),
        albedo=torch.from_numpy(albedo).float(),
    )

    # Close the handle immediately (delete=False keeps the file) so it does not
    # leak and Mesh.write can reopen the path on every platform.
    with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmp:
        temp_path = tmp.name
    mesh.write(temp_path)
    return temp_path