Refactor the generate3d function in inference.py to improve the mesh export process: introduce temporary-file handling for the OBJ, MTL, and PNG outputs to ensure proper texture mapping, and improve readability by restructuring renderer initialization and file-writing logic.
e5c94c9
import os
import tempfile
import time
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from util.renderer import Renderer
from util.utils import get_tri
def _split_views(strip, tile=256, num_views=6):
    """Cut a horizontally tiled image strip into a stack of per-view tiles.

    Args:
        strip: 3-D tensor indexed as ``[C, H, W]``. The last axis is cut into
            ``num_views`` consecutive slices of width ``tile``; assumes
            ``W >= tile * num_views``.
        tile: Width in pixels of one view.
        num_views: Number of views laid side by side in ``strip``.

    Returns:
        Tensor of shape ``[num_views, C, H, tile]``. The LAST tile comes
        first, followed by tiles ``0 .. num_views - 2`` in order. This is the
        ordering the pipeline has always used; presumably it is what
        ``get_tri`` expects (confirm before changing).
    """
    order = [num_views - 1, *range(num_views - 1)]
    return torch.stack([strip[:, :, tile * i:tile * (i + 1)] for i in order], dim=0)


def _write_textured_obj(model, data_config, out_dir, stem="mesh", material="material0"):
    """Export ``data_config``'s mesh as a linked OBJ + MTL + PNG triple.

    Writes ``<stem>.obj`` via ``model.export_mesh_geometry`` and
    ``<stem>.png`` via ``model.export_texture_image``. It then writes a
    ``<stem>.mtl`` declaring one material whose diffuse map is that PNG, and
    finally prefixes the OBJ with the statements that bind it to that
    material. The MTL and PNG are referenced by bare file name, so the three
    files must stay together in ``out_dir``.

    Args:
        model: Object providing ``export_mesh_geometry(data_config, path)`` and
            ``export_texture_image(data_config, path)``.
        data_config: Mesh/config dict handed unchanged to both export calls.
        out_dir: ``pathlib.Path`` of an existing directory to write into.
        stem: Base file name shared by the three outputs.
        material: Material name declared in the MTL and used by the OBJ.

    Returns:
        ``str`` path of the written ``.obj`` file.
    """
    obj_path = out_dir / f"{stem}.obj"
    mtl_path = obj_path.with_suffix(".mtl")
    texture_path = obj_path.with_suffix(".png")

    model.export_mesh_geometry(data_config, str(obj_path))  # writes .obj with UVs
    model.export_texture_image(data_config, str(texture_path))  # saves PNG

    # White diffuse colour so the texture map is shown without tinting.
    mtl_path.write_text(
        f"newmtl {material}\n"
        "Kd 1.000000 1.000000 1.000000\n"
        f"map_Kd {texture_path.name}\n",
        encoding="utf-8",
    )

    # `mtllib` only loads the library. Faces use a material only after a
    # `usemtl` statement, so add one unless the exporter already wrote its own.
    # NOTE(review): if export_mesh_geometry also writes an `mtllib` line, the
    # OBJ will reference two libraries; confirm it does not.
    obj_body = obj_path.read_text()
    header = f"mtllib {mtl_path.name}\n"
    if not any(line.split()[:1] == ["usemtl"] for line in obj_body.splitlines()):
        header += f"usemtl {material}\n"
    obj_path.write_text(header + obj_body)
    return str(obj_path)


def generate3d(model, rgb, ccm, device):
    """Reconstruct a textured mesh from tiled colour and coordinate images.

    Args:
        model: Reconstruction model. This function reads ``tet_grid_size``,
            ``camera_angle_num``, ``input.scale``, ``geo_type`` and
            ``denoising``, and calls ``scheduler.add_noise``, ``unet2``,
            ``decode``, ``export_mesh_geometry`` and ``export_texture_image``.
            Its ``renderer`` attribute is replaced on every call.
        rgb: 3-D numpy image of the colour views tiled side by side,
            presumably ``[H, W, C]`` uint8 in 0-255 (it is divided by 255);
            confirm against the caller.
        ccm: 3-D numpy image of the coordinate maps, tiled like ``rgb``. Only
            its first three channels are used, in reversed order (presumably
            BGR -> RGB; confirm).
        device: Torch device that the networks and mesh tensors live on.

    Returns:
        ``str`` path of an ``.obj`` file inside a fresh temporary directory,
        next to the ``.mtl`` and ``.png`` it references. Nothing here deletes
        that directory; cleanup is the caller's responsibility.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    # Scale to [0, 1] and move channels first: [H, W, C] -> [C, H, W].
    color = (torch.from_numpy(rgb) / 255).permute(2, 0, 1)
    xyz = (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)

    color_views = _split_views(color)
    xyz_views = _split_views(xyz)
    # Channels-last per-view colour images with a batch axis: [1, V, H, W, C].
    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

    model.eval()
    # Pure inference: keep every network call (including the non-denoising
    # unet2 pass and the exports) out of autograd to avoid retaining a graph.
    with torch.no_grad():
        if model.denoising:
            # randint over [20, 21) always yields 20. It is kept instead of
            # torch.full so the RNG stream that randn_like draws from next
            # (likely advanced by randint) stays as before for seeded runs.
            timestep = torch.randint(
                20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device
            )
            noise = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise, timestep)
            triplane_feature2 = model.unet2(triplane, timestep)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)

        from kiui.mesh_utils import clean_mesh

        # clean_mesh works on numpy arrays; the remeshed result goes back to `device`.
        clean_verts, clean_faces = clean_mesh(
            verts[0].squeeze().cpu().numpy().astype(np.float32),
            faces.squeeze().cpu().numpy().astype(np.int32),
            repair=False, remesh=True, remesh_size=0.005, remesh_iters=1,
        )
        data_config['verts'] = torch.from_numpy(clean_verts).to(device).contiguous()
        data_config['faces'] = torch.from_numpy(clean_faces).to(device).contiguous()

        # NOTE(review): unlike decode, the export calls are not given
        # triplane_feature2; confirm export_texture_image does not need it.
        # A private directory keeps the .obj/.mtl/.png siblings from colliding
        # with unrelated files in the shared temp dir.
        return _write_textured_obj(model, data_config, Path(tempfile.mkdtemp()))