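"""CPU-only textured mesh generation.

Decodes a triplane from six-view RGB and CCM images, extracts and cleans a mesh,
UV-unwraps it with xatlas, bakes an albedo texture by sampling the surface, and
exports the result as a .glb file.
"""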
import tempfile

import cv2
import numpy as np
import torch
import trimesh
import xatlas
from kiui.mesh_utils import clean_mesh

from mesh import Mesh
from util.renderer import Renderer
from util.utils import get_tri


def generate3d(model, rgb, ccm, device):
    """Decode a textured mesh from six-view RGB and CCM images and return the path of an exported .glb file."""
    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(color):
        # Split the concatenated strip into six 256-pixel-wide views and reorder them.
        return torch.stack([color[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)

    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color = get_imgs(color)
    xyz = get_imgs(xyz)

    # Build the color/xyz triplane inputs for the UNet.
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)
    model.eval()
    if model.denoising:
        # Perturb the triplane with scheduler noise at a fixed timestep, then denoise with the UNet.
        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)
    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color.to(device),
    }

    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
    )
    data_config['verts'] = torch.from_numpy(verts).contiguous()
    data_config['faces'] = torch.from_numpy(faces).contiguous()
    # CPU-only UV unwrapping with xatlas.
    mesh_v = data_config['verts'].cpu().numpy()
    mesh_f = data_config['faces'].cpu().numpy()
    # xatlas returns (vmapping, indices, uvs); `ft` indexes the per-corner UVs in `vt`.
    vmapping, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
    # Bake a texture with the model's decoder and rgbMlp (CPU-only).
    # Instead of rasterizing every face, sample points on the surface, query their colors,
    # and splat them into the texture at the UVs interpolated from the xatlas parametrization
    # (a CPU-friendly approximation of a proper rasterized bake).
    tex_res = (1024, 1024)
    mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
    points, face_indices = trimesh.sample.sample_surface(mesh_trimesh, tex_res[0] * tex_res[1])

    # Query a color for each sampled point from the triplane features.
    points_tensor = torch.from_numpy(points).float().unsqueeze(0).cpu()  # (1, N, 3) on CPU
    triplane_feature2 = triplane_feature2.cpu()  # ensure the features are on CPU as well
    with torch.no_grad():
        dec_verts = model.decoder(triplane_feature2, points_tensor)
        colors = model.rgbMlp(dec_verts).squeeze().cpu().numpy()  # (N, 3)
    colors = (colors * 0.5 + 0.5).clip(0, 1)

    # Interpolate each sampled point's UV from its face's corner UVs via barycentric weights.
    bary = trimesh.triangles.points_to_barycentric(mesh_v[mesh_f[face_indices]], points)
    uvs = np.einsum('ij,ijk->ik', bary, vt[ft[face_indices]])  # (N, 2) in [0, 1]

    # Splat the sampled colors into the texture. This assumes the albedo is stored with
    # v = 0 at image row 0 (kiui-style Mesh); flip the row index if the export looks mirrored.
    texture = np.zeros((tex_res[1], tex_res[0], 3), dtype=np.float32)
    hit = np.zeros((tex_res[1], tex_res[0]), dtype=np.uint8)
    px = np.clip((uvs[:, 0] * (tex_res[0] - 1)).round().astype(np.int64), 0, tex_res[0] - 1)
    py = np.clip((uvs[:, 1] * (tex_res[1] - 1)).round().astype(np.int64), 0, tex_res[1] - 1)
    texture[py, px] = colors
    hit[py, px] = 255

    # Fill texels that received no sample by inpainting from their neighbours
    # (this can be slow for large empty regions; lower tex_res to trade quality for speed).
    texture = cv2.inpaint((texture * 255).astype(np.uint8), 255 - hit, 3, cv2.INPAINT_TELEA)
    texture = np.clip(texture.astype(np.float32) / 255, 0, 1)
    # Assemble the textured mesh and export it as a .glb file.
    mesh = Mesh(
        v=torch.from_numpy(mesh_v).float(),
        f=torch.from_numpy(mesh_f).int(),
        vt=torch.from_numpy(vt).float(),
        ft=torch.from_numpy(ft.astype(np.int64)).int(),  # xatlas returns uint32, which torch.from_numpy may reject
        albedo=torch.from_numpy(texture).float()
    )
    temp_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
    mesh.write(temp_path)
    return temp_path
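
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the loader below is an assumption, not part
# of this file). `model` is the reconstruction model built elsewhere in the
# repo, and `rgb` / `ccm` are the concatenated six-view color and coordinate
# map images (uint8, six 256-pixel-wide views side by side):
#
#     model = build_model(config)                        # hypothetical loader
#     glb_path = generate3d(model, rgb, ccm, device="cpu")
#     print("exported", glb_path)
# ---------------------------------------------------------------------------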