Commit dfab55e
Parent(s): b2dc9cc
Refactor generate3d function in inference.py to improve readability and maintainability. Enhanced RGB and coordinate conversion, streamlined noise addition for denoising, and updated mesh export process to utilize trimesh for GLB format, ensuring proper handling of UV textures.
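For context on the "noise addition for denoising" mentioned above, the sketch below shows what a scheduler's add_noise call does to a feature tensor at a fixed timestep. It is a minimal stand-in, assuming a diffusers-style DDPMScheduler; the project's actual model.scheduler, tensor shapes, and timestep choice are the ones shown in the diff further down.

    import torch
    from diffusers import DDPMScheduler  # assumption: the repo's scheduler exposes this interface

    scheduler = DDPMScheduler(num_train_timesteps=1000)   # hypothetical stand-in for model.scheduler

    triplane = torch.randn(1, 8, 256, 1536)               # placeholder triplane feature tensor
    tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long)  # fixed timestep t=20, as in the diff
    noise_new = torch.randn_like(triplane) * 0.5 + 0.5    # shifted Gaussian noise, as in the diff
    noisy = scheduler.add_noise(triplane, noise_new, tnew)  # blends signal and noise per the schedule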
Files changed: inference.py (+55, -63)

inference.py (CHANGED)
@@ -7,93 +7,85 @@ import tempfile
Old version (removed lines are prefixed with "-"; a bare "-" marks a removed line whose content is not shown):

 from mesh import Mesh
 import zipfile
 from util.renderer import Renderer
-
-
-    model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
-                              scale=model.input.scale, geo_type = model.geo_type)
-
-    color_tri = torch.from_numpy(rgb)/255
-    xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
-    color = color_tri.permute(2,0,1)
-    xyz = xyz_tri.permute(2,0,1)

     def get_imgs(color):
-
-
-
-
-            color_list.append(color[:,:,256*i:256*(1+i)])
-        return torch.stack(color_list, dim=0)# [6, C, H, W]
-
-    triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]

     color = get_imgs(color)
     xyz = get_imgs(xyz)

-    color = get_tri(color, dim=0, blender=
-    xyz = get_tri(xyz, dim=0, blender=

-    triplane = torch.cat([color,xyz],dim=1).to(device)
-    # 3D visualize
     model.eval()
-
-
-
-
-
-        noise_new = torch.randn_like(triplane) *0.5+0.5
-        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
-        start_time = time.time()
         with torch.no_grad():
-            triplane_feature2 = model.unet2(triplane,tnew)
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        print(f"unet takes {elapsed_time}s")
     else:
         triplane_feature2 = model.unet2(triplane)
-

     with torch.no_grad():
         data_config = {
             'resolution': [1024, 1024],
-
         }

         verts, faces = model.decode(data_config, triplane_feature2)
-
         data_config['verts'] = verts[0]
         data_config['faces'] = faces
-

         from kiui.mesh_utils import clean_mesh
-        verts, faces = clean_mesh(
         data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
         data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()

-
     with torch.no_grad():
-
-
-
-
-
-
-
-
-
-
-
-
-    # # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
-    # # mesh_obj2.export(mesh_path_obj2+".obj")
-
-    # with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip:
-    #     myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj')
-    #     myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png')
-    #     myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl')
-
-    end_time = time.time()
-    elapsed_time = end_time - start_time
-    print(f"uv takes {elapsed_time}s")
-    return mesh_path_glb+".glb"
New version (added lines are prefixed with "+"):

 from mesh import Mesh
 import zipfile
 from util.renderer import Renderer
+import trimesh  # Needed for glb export

+def generate3d(model, rgb, ccm, device):
+    model.renderer = Renderer(
+        tet_grid_size=model.tet_grid_size,
+        camera_angle_num=model.camera_angle_num,
+        scale=model.input.scale,
+        geo_type=model.geo_type
+    )
+
+    # RGB and coordinate conversion
+    color_tri = torch.from_numpy(rgb) / 255
+    xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
+    color = color_tri.permute(2, 0, 1)
+    xyz = xyz_tri.permute(2, 0, 1)

     def get_imgs(color):
+        color_list = [color[:, :, 256 * 5:256 * (1 + 5)]]
+        for i in range(0, 5):
+            color_list.append(color[:, :, 256 * i:256 * (1 + i)])
+        return torch.stack(color_list, dim=0)

+    triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
     color = get_imgs(color)
     xyz = get_imgs(xyz)

+    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
+    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
+
+    triplane = torch.cat([color, xyz], dim=1).to(device)

     model.eval()
+
+    if model.denoising:
+        tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
+        noise_new = torch.randn_like(triplane) * 0.5 + 0.5
+        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
         with torch.no_grad():
+            triplane_feature2 = model.unet2(triplane, tnew)
     else:
         triplane_feature2 = model.unet2(triplane)

     with torch.no_grad():
         data_config = {
             'resolution': [1024, 1024],
+            'triview_color': triplane_color.to(device),
         }

         verts, faces = model.decode(data_config, triplane_feature2)
         data_config['verts'] = verts[0]
         data_config['faces'] = faces

+        # Optional mesh cleanup (reduce remesh for speed)
         from kiui.mesh_utils import clean_mesh
+        verts, faces = clean_mesh(
+            data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
+            data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
+            repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
+        )
         data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
         data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()

+    # Rasterization context
+    glctx = dr.RasterizeGLContext()
+
+    # Temporary output path
+    mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
+
+    # Export OBJ with UV and PNG
     with torch.no_grad():
+        model.export_mesh_wt_uv(
+            glctx, data_config, mesh_path_obj, "", device,
+            res=(512, 512), tri_fea_2=triplane_feature2
+        )
+
+    # Convert to .glb using trimesh
+    mesh = trimesh.load(mesh_path_obj + ".obj", force='mesh')
+    mesh_path_glb = mesh_path_obj + ".glb"
+    mesh.export(mesh_path_glb, file_type='glb')
+
+    print(f"✅ Exported GLB with UV texture: {mesh_path_glb}")
+    return mesh_path_glb
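As a usage sketch (not part of the commit), the refactored generate3d might be driven as below. The input layout (six 256-pixel-wide views tiled horizontally) is inferred from the slicing in get_imgs, and model stands for the pre-loaded pipeline object the repo constructs elsewhere; both are assumptions, as is the final trimesh check, which only confirms the GLB loads with geometry and a visual attached.

    import numpy as np
    import torch
    import trimesh
    from inference import generate3d  # the function changed in this commit


    def export_and_check(model, rgb, ccm):
        """Run generate3d on one sample and sanity-check the exported GLB."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        glb_path = generate3d(model, rgb, ccm, device)

        # The GLB should load back as a single textured mesh.
        mesh = trimesh.load(glb_path, force='mesh')
        print(glb_path, mesh.vertices.shape, mesh.faces.shape, type(mesh.visual).__name__)
        return glb_path


    # Placeholder inputs: six 256x256 views tiled into 256x1536 uint8 images (assumed layout).
    rgb = np.zeros((256, 256 * 6, 3), dtype=np.uint8)
    ccm = np.zeros((256, 256 * 6, 3), dtype=np.uint8)
    # export_and_check(model, rgb, ccm)  # requires the repo's pre-loaded `model`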