# CRM / inference.py
# Commit 44c19b3 (YoussefAnso): Refactor texture baking in
# vertex_color_to_uv_textured_glb to use vectorized operations; face data is
# precomputed and barycentric tests run in batches, speeding up texture generation.
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
import trimesh
import xatlas
import cv2
from PIL import Image, ImageFilter


def vertex_color_to_uv_textured_glb(obj_path, glb_path, texture_size=512):
    mesh = trimesh.load(obj_path, process=False)
    vertex_colors = mesh.visual.vertex_colors[:, :3]  # (N, 3), uint8

    # Generate UVs
    vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
    vertices = mesh.vertices[vmapping]
    vertex_colors = vertex_colors[vmapping]
    mesh.vertices = vertices
    mesh.faces = indices
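    # xatlas may split vertices along UV seams, so vmapping is used above to
    # reindex both positions and per-vertex colors to the new vertex order.
    # The baking below rasterizes every face into a buffer at twice the requested
    # texture size; the result is later inpainted, filtered, and downsampled.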
    # Bake texture (vectorized)
    buffer_size = texture_size * 2
    texture_buffer = np.zeros((buffer_size, buffer_size, 4), dtype=np.uint8)

    # Precompute face data
    face_uvs = uvs[mesh.faces]
    face_colors = vertex_colors[mesh.faces]

    # Compute bounding boxes for all faces
    min_xy = np.floor(np.min(face_uvs, axis=1) * (buffer_size - 1)).astype(int)
    max_xy = np.ceil(np.max(face_uvs, axis=1) * (buffer_size - 1)).astype(int)
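    # Per-face rasterization loop: sample pixel centers (the +0.5 offset) inside
    # each face's UV bounding box, keep the ones that pass the triangle inside
    # test, and fill them with barycentric-interpolated vertex colors.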
    for i in range(len(mesh.faces)):
        uv0, uv1, uv2 = face_uvs[i]
        c0, c1, c2 = face_colors[i]
        min_x, min_y = min_xy[i]
        max_x, max_y = max_xy[i]

        # Create a grid of pixel coordinates in the bounding box
        xs = np.arange(min_x, max_x + 1)
        ys = np.arange(min_y, max_y + 1)
        xv, yv = np.meshgrid(xs, ys)
        pts = np.stack([xv, yv], axis=-1).reshape(-1, 2) + 0.5

        # Inside-triangle test (vectorized): a pixel is inside when the three
        # edge-sign terms do not disagree (handles both triangle windings)
        v0, v1, v2 = uv0 * (buffer_size - 1), uv1 * (buffer_size - 1), uv2 * (buffer_size - 1)

        def sign(p1, p2, p3):
            return (p1[..., 0] - p3[0]) * (p2[1] - p3[1]) - (p2[0] - p3[0]) * (p1[..., 1] - p3[1])

        d1 = sign(pts, v0, v1)
        d2 = sign(pts, v1, v2)
        d3 = sign(pts, v2, v0)
        has_neg = (d1 < 0) | (d2 < 0) | (d3 < 0)
        has_pos = (d1 > 0) | (d2 > 0) | (d3 > 0)
        mask = ~(has_neg & has_pos)
        inside_pts = pts[mask]
        if len(inside_pts) == 0:
            continue
        # Barycentric coordinates (vectorized)
        def barycentric(p, v0, v1, v2):
            v0v1 = v1 - v0
            v0v2 = v2 - v0
            v0p = p - v0
            d00 = np.dot(v0v1, v0v1)
            d01 = np.dot(v0v1, v0v2)
            d11 = np.dot(v0v2, v0v2)
            d20 = np.dot(v0p, v0v1)
            d21 = np.dot(v0p, v0v2)
            denom = d00 * d11 - d01 * d01
            v = (d11 * d20 - d01 * d21) / denom
            w = (d00 * d21 - d01 * d20) / denom
            u = 1.0 - v - w
            return np.clip(u, 0, 1), np.clip(v, 0, 1), np.clip(w, 0, 1)

        u, v, w = barycentric(inside_pts, v0, v1, v2)
        colors = u[:, None] * c0 + v[:, None] * c1 + w[:, None] * c2
        xi = inside_pts[:, 0].astype(int)
        yi = inside_pts[:, 1].astype(int)
        texture_buffer[yi, xi, :3] = np.clip(colors, 0, 255).astype(np.uint8)
        texture_buffer[yi, xi, 3] = 255
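    # Texels not covered by any face keep alpha 0; the TELEA inpainting below
    # fills them from neighbouring colors so that sampling near UV seams is less
    # likely to pick up the empty background. The buffer is then median/Gaussian
    # filtered and downsampled to the requested texture size.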
    # Inpainting, filtering, and downsampling
    image_bgra = texture_buffer.copy()
    mask = (image_bgra[:, :, 3] == 0).astype(np.uint8) * 255
    image_bgr = cv2.cvtColor(image_bgra, cv2.COLOR_BGRA2BGR)
    inpainted_bgr = cv2.inpaint(image_bgr, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
    inpainted_bgra = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2BGRA)
    texture_buffer = inpainted_bgra[::-1]
    image_texture = Image.fromarray(texture_buffer)
    image_texture = image_texture.filter(ImageFilter.MedianFilter(size=3))
    image_texture = image_texture.filter(ImageFilter.GaussianBlur(radius=1))
    image_texture = image_texture.resize((texture_size, texture_size), Image.LANCZOS)
    # Assign UVs and texture to mesh
    material = trimesh.visual.material.PBRMaterial(
        baseColorFactor=[1.0, 1.0, 1.0, 1.0],
        baseColorTexture=image_texture,
        metallicFactor=0.0,
        roughnessFactor=1.0,
    )
    visuals = trimesh.visual.TextureVisuals(uv=uvs, material=material)
    mesh.visual = visuals
    mesh.export(glb_path)
    image_texture.save("debug_texture.png")
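
# Example (sketch): baking a vertex-colored OBJ into a UV-textured GLB. The file
# names are illustrative; any mesh whose OBJ carries per-vertex colors should work:
#
#     vertex_color_to_uv_textured_glb("mesh_with_vertex_colors.obj",
#                                     "mesh_textured.glb", texture_size=1024)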


def generate3d(model, rgb, ccm, device=None):
    # Default to CUDA, but honor an explicitly passed device
    device = torch.device(device) if device is not None else torch.device("cuda")

    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
                              scale=model.input.scale, geo_type=model.geo_type)

    color_tri = torch.from_numpy(rgb).to(device) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]).to(device) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)
    def get_imgs(color):
        color_list = []
        color_list.append(color[:, :, 256 * 5:256 * (1 + 5)])
        for i in range(0, 5):
            color_list.append(color[:, :, 256 * i:256 * (1 + i)])
        return torch.stack(color_list, dim=0)
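    # The inputs appear to be horizontal strips of six 256x256 views (width 1536);
    # get_imgs reorders them so the sixth view comes first, followed by views 0-4,
    # before get_tri assembles them into triplane inputs.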
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color = get_imgs(color)
    xyz = get_imgs(xyz)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0).to(device)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0).to(device)

    triplane = torch.cat([color, xyz], dim=1).to(device)
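    # Optional denoising pass: if enabled, noise is injected at a fixed timestep
    # (t = 20) via the scheduler and unet2 refines the noisy triplane; otherwise
    # unet2 runs directly on the clean triplane features.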
    model.eval()

    if model.denoising:
        tnew = 20
        tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        start_time = time.time()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        triplane_feature2 = model.unet2(triplane)
    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color.to(device),
        }

        verts, faces = model.decode(data_config, triplane_feature2)

        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
    start_time = time.time()
    with torch.no_grad():
        mesh_path_glb = tempfile.NamedTemporaryFile(suffix="", delete=False).name
        model.export_mesh(data_config, mesh_path_glb, tri_fea_2=triplane_feature2)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"uv takes {elapsed_time}s")

    # export_mesh writes a vertex-colored OBJ; bake it into a UV-textured GLB
    obj_path = mesh_path_glb + ".obj"
    glb_path = mesh_path_glb + ".glb"
    vertex_color_to_uv_textured_glb(obj_path, glb_path)

    return glb_path
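
# Example (sketch): typical call site, assuming `model` is a loaded CRM pipeline
# and `rgb` / `ccm` are the multi-view color and canonical coordinate map images
# produced by the upstream diffusion stages (variable names are illustrative):
#
#     glb_path = generate3d(model, rgb, ccm, device="cuda")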