Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Change device setting for mesh data in generate3d function to CPU for consistency with recent updates, ensuring compatibility across systems without GPU support.
82982c9
raw
history blame
6.89 kB
import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from util.utils import get_tri
import tempfile
from mesh import Mesh
import zipfile
from util.renderer import Renderer
import trimesh
import xatlas
import cv2
from PIL import Image, ImageFilter
def vertex_color_to_uv_textured_glb(obj_path, glb_path, texture_size=640):
    """Bake per-vertex colours of a mesh into a UV texture and export a GLB.

    The mesh at ``obj_path`` is loaded with trimesh, unwrapped with xatlas,
    rasterised triangle by triangle into a texture at twice the target
    resolution, inpainted where no triangle wrote a texel, filtered,
    downsampled to ``texture_size`` and attached to the mesh as a PBR
    base-colour texture.

    Args:
        obj_path: Path of the input mesh; it must expose
            ``mesh.visual.vertex_colors``.
        glb_path: Destination path handed to ``mesh.export``.
        texture_size: Edge length in pixels of the final square texture.

    Side effects:
        Writes the exported mesh to ``glb_path`` and the baked texture to
        ``debug_texture.png`` in the current working directory.
    """
    mesh = trimesh.load(obj_path, process=False)
    vertex_colors = mesh.visual.vertex_colors[:, :3]  # (N, 3), uint8

    # Generate UVs. xatlas may duplicate vertices along seams; ``vmapping``
    # gives, for every new vertex, the index of the original one.
    vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
    vertices = mesh.vertices[vmapping]
    # Promote to float once: uint8 arithmetic wraps modulo 256, which
    # corrupted the averaged colour used for degenerate triangles.
    vertex_colors = vertex_colors[vmapping].astype(np.float64)
    mesh.vertices = vertices
    mesh.faces = indices

    # Bake texture at 2x resolution; it is downsampled at the end.
    buffer_size = texture_size * 2
    texture_buffer = np.zeros((buffer_size, buffer_size, 4), dtype=np.uint8)

    def barycentric_interpolate(v0, v1, v2, c0, c1, c2, p):
        """Interpolate colours c0..c2 at point p inside triangle v0, v1, v2."""
        v0v1 = v1 - v0
        v0v2 = v2 - v0
        v0p = p - v0
        d00 = np.dot(v0v1, v0v1)
        d01 = np.dot(v0v1, v0v2)
        d11 = np.dot(v0v2, v0v2)
        d20 = np.dot(v0p, v0v1)
        d21 = np.dot(v0p, v0v2)
        denom = d00 * d11 - d01 * d01
        if abs(denom) < 1e-8:
            # Zero-area triangle: weights are undefined, use the mean colour.
            return (c0 + c1 + c2) / 3
        v = (d11 * d20 - d01 * d21) / denom
        w = (d00 * d21 - d01 * d20) / denom
        u = 1.0 - v - w
        # Guard against weights slightly outside [0, 1] on triangle edges.
        u = np.clip(u, 0, 1)
        v = np.clip(v, 0, 1)
        w = np.clip(w, 0, 1)
        return u * c0 + v * c1 + w * c2

    def is_point_in_triangle(p, v0, v1, v2):
        """Return True when p is inside or on the edge of the triangle."""
        def sign(p1, p2, p3):
            return (p1[0] - p3[0]) * (p2[1] - p3[1]) - (p2[0] - p3[0]) * (p1[1] - p3[1])

        d1 = sign(p, v0, v1)
        d2 = sign(p, v1, v2)
        d3 = sign(p, v2, v0)
        has_neg = (d1 < 0) or (d2 < 0) or (d3 < 0)
        has_pos = (d1 > 0) or (d2 > 0) or (d3 > 0)
        # Inside means all edge functions share a sign (zeros allowed).
        return not (has_neg and has_pos)

    last_texel = buffer_size - 1
    for face in mesh.faces:
        uv0, uv1, uv2 = uvs[face]
        c0, c1, c2 = vertex_colors[face]
        # Quantise UVs to integer texel coordinates.
        uv0 = (uv0 * last_texel).astype(int)
        uv1 = (uv1 * last_texel).astype(int)
        uv2 = (uv2 * last_texel).astype(int)

        # Bounding box of the triangle, clamped to the texture.
        min_x = max(int(min(uv0[0], uv1[0], uv2[0])), 0)
        max_x = min(int(max(uv0[0], uv1[0], uv2[0])), last_texel)
        min_y = max(int(min(uv0[1], uv1[1], uv2[1])), 0)
        max_y = min(int(max(uv0[1], uv1[1], uv2[1])), last_texel)

        for y in range(min_y, max_y + 1):
            for x in range(min_x, max_x + 1):
                p = np.array([x + 0.5, y + 0.5])  # texel centre
                if is_point_in_triangle(p, uv0, uv1, uv2):
                    color = barycentric_interpolate(uv0, uv1, uv2, c0, c1, c2, p)
                    texture_buffer[y, x, :3] = np.clip(color, 0, 255).astype(np.uint8)
                    texture_buffer[y, x, 3] = 255  # mark texel as written

    # Inpainting, filtering, and downsampling
    image_bgra = texture_buffer.copy()
    # Texels never touched by a triangle still have alpha 0; inpaint those.
    mask = (image_bgra[:, :, 3] == 0).astype(np.uint8) * 255
    # cv2.inpaint is fed a 3-channel image, so alpha is stripped and re-added.
    image_bgr = cv2.cvtColor(image_bgra, cv2.COLOR_BGRA2BGR)
    inpainted_bgr = cv2.inpaint(image_bgr, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
    inpainted_bgra = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2BGRA)
    # Flip rows; presumably because image rows run top-down while the UV
    # v axis runs bottom-up -- confirm against the viewer's convention.
    texture_buffer = inpainted_bgra[::-1]
    image_texture = Image.fromarray(texture_buffer)
    image_texture = image_texture.filter(ImageFilter.MedianFilter(size=3))
    image_texture = image_texture.filter(ImageFilter.GaussianBlur(radius=1))
    image_texture = image_texture.resize((texture_size, texture_size), Image.LANCZOS)

    # Assign UVs and texture to mesh
    material = trimesh.visual.material.PBRMaterial(
        baseColorFactor=[1.0, 1.0, 1.0, 1.0],
        baseColorTexture=image_texture,
        metallicFactor=0.0,
        roughnessFactor=1.0,
    )
    visuals = trimesh.visual.TextureVisuals(uv=uvs, material=material)
    mesh.visual = visuals
    mesh.export(glb_path)
    # NOTE(review): debug artefact written to the current working directory
    # on every call; kept for behavioural compatibility.
    image_texture.save("debug_texture.png")
def generate3d(model, rgb, ccm, device):
    """Reconstruct a mesh from six-view colour/coordinate images, return a GLB path.

    Args:
        model: Project model object. This function reads its
            ``tet_grid_size``, ``camera_angle_num``, ``input.scale``,
            ``geo_type`` and ``denoising`` attributes, calls ``eval``,
            ``unet2``, ``decode`` and ``export_mesh`` (plus
            ``scheduler.add_noise`` when denoising), and overwrites
            ``model.renderer``.
        rgb: NumPy image indexed (H, W, C) whose width holds six
            256-pixel-wide views side by side; values are divided by 255.
        ccm: NumPy image with the same layout; its last axis is reordered
            with ``(2, 1, 0)`` before use.
        device: Torch device on which the network is run.

    Returns:
        Path of the exported ``.glb`` file.

    Side effects:
        Prints timing information and leaves temporary files on disk (the
        placeholder created by ``NamedTemporaryFile`` and the ``.obj`` /
        ``.glb`` files next to it).
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(color):
        # color : [C, H, W*6]
        # The sixth view goes first, followed by views 0..4.
        color_list = [color[:, :, 256 * 5:256 * (1 + 5)]]
        for i in range(0, 5):
            color_list.append(color[:, :, 256 * i:256 * (1 + i)])
        return torch.stack(color_list, dim=0)  # [6, C, H, W]

    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)  # [1, 6, H, W, C]

    # get_imgs is called again on purpose so the tensor handed to get_tri
    # does not share storage with triplane_color.
    color = get_imgs(color)
    xyz = get_imgs(xyz)

    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)

    triplane = torch.cat([color, xyz], dim=1).to(device)
    # 3D visualize
    model.eval()

    if model.denoising == True:  # kept as an equality test to preserve behaviour
        tnew = 20
        # Range [20, 21) always yields 20; kept as randint so RNG
        # consumption (and thus the noise below) is unchanged.
        tnew = torch.randint(tnew, tnew + 1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        start_time = time.time()
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"unet takes {elapsed_time}s")
    else:
        # Inference only: do not record an autograd graph, matching the
        # denoising branch above.
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    with torch.no_grad():
        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)
        data_config['verts'] = verts[0]
        data_config['faces'] = faces

    # Imported here, as in the original, so kiui is only needed when this
    # function actually runs.
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False,
        remesh=True,
        remesh_size=0.005,
        remesh_iters=1,
    )
    # Mesh data is deliberately kept on the CPU.
    data_config['verts'] = torch.from_numpy(verts).to('cpu').contiguous()
    data_config['faces'] = torch.from_numpy(faces).to('cpu').contiguous()

    start_time = time.time()
    with torch.no_grad():
        # Only the unique name is needed; close the handle right away
        # instead of leaving it to the garbage collector.
        with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp_file:
            mesh_path_glb = tmp_file.name
        model.export_mesh(data_config, mesh_path_glb, tri_fea_2=triplane_feature2)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"uv takes {elapsed_time}s")

    # Convert .obj (with vertex colors) to UV-mapped textured .glb
    # (export_mesh is presumed to have written ``<path>.obj`` -- confirm).
    obj_path = mesh_path_glb + ".obj"
    glb_path = mesh_path_glb + ".glb"
    vertex_color_to_uv_textured_glb(obj_path, glb_path)
    return glb_path