Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Update imports in inference.py to include numpy, PIL, and mesh utilities for enhanced functionality and improved mesh processing.
d63ad23
raw
history blame
2.72 kB
import numpy as np
import torch
import time
import tempfile
import os
from PIL import Image
import trimesh
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
def _split_views(images, view_size=256):
    """Cut a horizontal strip of six square views into a stacked tensor.

    Args:
        images: tensor indexed as (C, H, W), with the views laid out side by
            side along the last (width) axis; W is expected to be at least
            6 * view_size.
        view_size: width in pixels of a single view tile.

    Returns:
        Tensor of shape (6, C, H, view_size). Tile 5 (the right-most one) is
        placed first, followed by tiles 0-4 in order.
    """
    order = (5, 0, 1, 2, 3, 4)
    tiles = [images[:, :, view_size * i:view_size * (i + 1)] for i in order]
    return torch.stack(tiles, dim=0)


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from six-view color/CCM strips and export it as GLB.

    Args:
        model: the reconstruction model. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``,
            ``scheduler`` and ``unet2`` from it, and calls ``decode`` and
            ``export_mesh_wt_uv``.
        rgb: numpy image holding six color views side by side along the width;
            presumably uint8 of shape (H, 6 * 256, 3) -- confirm with the caller.
        ccm: numpy image in the same layout as ``rgb`` (presumably canonical
            coordinate maps -- confirm); its channels are reordered (2, 1, 0)
            before use.
        device: torch device for the network and the mesh tensors. It is also
            passed to nvdiffrast's ``RasterizeCudaContext``, so it is expected
            to be a CUDA device.

    Returns:
        Path (str) of the exported ``.glb`` file, located in a fresh temporary
        directory next to the intermediate ``.obj`` export.

    Side effects:
        Replaces ``model.renderer`` and puts ``model`` into eval mode.
    """
    # These names were used below but never imported in this module, so every
    # call failed with NameError. Imported locally (as the upstream CRM
    # inference script does) to keep GPU-only dependencies off the module
    # import path.
    import nvdiffrast.torch as dr
    from kiui.mesh_utils import clean_mesh

    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # HWC numpy images -> CHW float tensors scaled by 1/255.
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    # (1, 6, H, W, C) color views, handed to the decoder via data_config.
    triplane_color = _split_views(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_tri = get_tri(_split_views(color), dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(_split_views(xyz), dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    # Inference only: keep every network call out of autograd.
    with torch.no_grad():
        if model.denoising:
            # randint(20, 21) always yields 20; it is kept (rather than
            # torch.full) so the random-number stream is consumed exactly as
            # before and seeded results do not change.
            tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)

    # Clean/remesh on the CPU with fixed settings.
    clean_verts, clean_faces = clean_mesh(
        verts[0].squeeze().cpu().numpy().astype(np.float32),
        faces.squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    # Move back onto `device` (previously a hard-coded .cuda(), which ignored it).
    data_config['verts'] = torch.from_numpy(clean_verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(clean_faces).to(device).contiguous()

    # One private directory per call: no orphan placeholder file or dangling
    # handle, and whatever the exporter writes next to the .obj stays together.
    obj_base = os.path.join(tempfile.mkdtemp(), "mesh")  # extension-less base path

    # Export mesh with UV and texture; rasterize on the same device as the mesh.
    with torch.no_grad():
        glctx = dr.RasterizeCudaContext(device=device)
        model.export_mesh_wt_uv(
            glctx, data_config, obj_base, "", device, res=(1024, 1024), tri_fea_2=triplane_feature2
        )

    # Load the textured .obj and re-export it as .glb.
    mesh = trimesh.load(obj_base + ".obj", process=False)
    glb_path = obj_base + ".glb"
    mesh.export(glb_path)
    return glb_path