Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
debug
76eeb7d
raw
history blame
3.83 kB
import numpy as np
import torch
import time
import tempfile
import os
from PIL import Image
import zipfile
import trimesh
from util.utils import get_tri
from mesh import Mesh
from util.renderer import Renderer
def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from a strip of six views and export it as GLB.

    Steps visible here: attach a fresh ``Renderer`` to ``model``, slice the RGB
    and CCM strips into six 256-pixel-wide views, pack them into a triplane with
    ``get_tri``, run ``model.unet2`` (adding scheduler noise at a fixed timestep
    when ``model.denoising`` is set), decode a mesh with ``model.decode``,
    clean/remesh it with ``kiui``, export it with UVs and a texture through
    ``model.export_mesh_wt_uv``, and finally write a textured ``.glb``.

    Args:
        model: Reconstruction model. Reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type``, ``denoising``,
            ``scheduler``, ``unet2``, ``decode`` and ``export_mesh_wt_uv``.
        rgb: HWC numpy image holding the six views side by side (width is
            sliced in 256-pixel steps); values presumably in [0, 255] since
            they are divided by 255 — TODO confirm dtype/height with the caller.
        ccm: Canonical-coordinate image with the same layout as ``rgb``; its
            channel order is reversed (presumably BGR -> RGB) before use.
        device: torch device used for the network inputs and the cleaned mesh.

    Returns:
        str: path of the exported ``<tmp>_textured.glb`` file.

    Raises:
        RuntimeError: if the exporter did not write the ``<tmp>.png`` texture.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # HWC scaled by 1/255, then CHW. The CCM channels are reversed first.
    color_tri = torch.from_numpy(rgb) / 255
    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
    color = color_tri.permute(2, 0, 1)
    xyz = xyz_tri.permute(2, 0, 1)

    def get_imgs(img_tensor):
        # Cut the strip into six 256-wide views, moving view 5 to the front.
        images = [img_tensor[:, :, 256 * 5:256 * (1 + 5)]]
        images.extend(img_tensor[:, :, 256 * i:256 * (i + 1)] for i in range(5))
        return torch.stack(images, dim=0)  # [6, C, H, 256]

    # Separate stacked copy (own storage) of the colour views for the decoder.
    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
    color = get_imgs(color)
    xyz = get_imgs(xyz)
    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color, xyz], dim=1).to(device)

    model.eval()
    if model.denoising:
        # randint's upper bound is exclusive, so this is always timestep 20.
        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane, tnew)
    else:
        with torch.no_grad():
            triplane_feature2 = model.unet2(triplane)

    data_config = {
        'resolution': [1024, 1024],
        "triview_color": triplane_color,  # already on `device`
    }
    with torch.no_grad():
        verts, faces = model.decode(data_config, triplane_feature2)
    data_config['verts'] = verts[0]
    data_config['faces'] = faces

    # Clean / remesh on the CPU with kiui. Move the result to `device`, not a
    # hard-coded CUDA device, so it matches the rest of the pipeline and the
    # `device` argument passed to the exporter below.
    from kiui.mesh_utils import clean_mesh
    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

    # Reserve a unique base path; "<base>.obj" and "<base>.png" are read back
    # below. The `with` closes the handle (delete=False keeps the file itself).
    with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp:
        mesh_path_obj = tmp.name
    model.export_mesh_wt_uv(
        None,  # GL context not needed for CPU export
        data_config,
        mesh_path_obj,
        "",
        device,
        res=(1024, 1024),
        tri_fea_2=triplane_feature2,
    )

    texture_path = mesh_path_obj + ".png"
    if not os.path.exists(texture_path):
        raise RuntimeError("Texture image not created, cannot export textured GLB.")

    scene_or_mesh = trimesh.load(mesh_path_obj + ".obj", force='scene')
    # Read the texture fully into memory and release the file handle. Keep it
    # as a PIL image: trimesh documents `image` as PIL.Image, and its glTF
    # exporter appears to skip non-PIL images (verify for the installed version).
    with Image.open(texture_path) as img:
        texture_image = img.copy()

    mesh_path_glb = mesh_path_obj + "_textured.glb"
    if isinstance(scene_or_mesh, trimesh.Scene):
        geometries = list(scene_or_mesh.geometry.values())
    else:
        geometries = [scene_or_mesh]
    for geometry in geometries:
        # Carry over the UVs loaded from the OBJ; without `uv` the new visual
        # has no mapping of the texture onto the surface.
        uv = getattr(geometry.visual, "uv", None)
        material = trimesh.visual.texture.SimpleMaterial(image=texture_image)
        geometry.visual = trimesh.visual.texture.TextureVisuals(
            uv=uv, image=texture_image, material=material
        )
    scene_or_mesh.export(mesh_path_glb)
    return mesh_path_glb