Commit
·
a14c9ce
1
Parent(s):
2cf3ed2
Refactor the `generate3d` function in `inference.py` to improve mesh processing and export logic. Streamline image tensor handling, improve denoising integration, and update mesh export to include UV mapping. Ensure proper handling of temporary files for the OBJ and GLB formats, improving overall readability and maintainability.
Browse files
- inference.py +40 -43
inference.py
CHANGED
@@ -1,32 +1,21 @@
|
|
1 |
-
|
2 |
-
import torch
|
3 |
-
import time
|
4 |
-
import tempfile
|
5 |
import trimesh
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
9 |
|
10 |
def generate3d(model, rgb, ccm, device):
|
11 |
-
model.renderer = Renderer(
|
12 |
-
|
13 |
-
camera_angle_num=model.camera_angle_num,
|
14 |
-
scale=model.input.scale,
|
15 |
-
geo_type=model.geo_type
|
16 |
-
)
|
17 |
|
18 |
color_tri = torch.from_numpy(rgb) / 255
|
19 |
xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
|
20 |
-
|
21 |
color = color_tri.permute(2, 0, 1)
|
22 |
xyz = xyz_tri.permute(2, 0, 1)
|
23 |
|
24 |
-
def get_imgs(
|
25 |
-
|
26 |
-
images.append(img_tensor[:, :, 256*5:256*(1+5)])
|
27 |
-
for i in range(5):
|
28 |
-
images.append(img_tensor[:, :, 256*i:256*(i+1)])
|
29 |
-
return torch.stack(images, dim=0) # [6, C, H, W]
|
30 |
|
31 |
triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
|
32 |
color = get_imgs(color)
|
@@ -34,41 +23,49 @@ def generate3d(model, rgb, ccm, device):
|
|
34 |
|
35 |
color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
|
36 |
xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
|
37 |
-
|
38 |
triplane = torch.cat([color, xyz], dim=1).to(device)
|
39 |
|
40 |
model.eval()
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
|
47 |
triplane_feature2 = model.unet2(triplane, tnew)
|
48 |
-
|
|
|
49 |
triplane_feature2 = model.unet2(triplane)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
with torch.no_grad():
|
52 |
-
data_config = {
|
53 |
-
'resolution': [1024, 1024],
|
54 |
-
"triview_color": triplane_color.to(device),
|
55 |
-
}
|
56 |
verts, faces = model.decode(data_config, triplane_feature2)
|
|
|
|
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
faces.squeeze().cpu().numpy().astype(np.int32),
|
62 |
repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
|
63 |
)
|
|
|
|
|
64 |
|
65 |
-
#
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
mesh
|
|
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
mesh.
|
|
|
73 |
|
74 |
-
return
|
|
|
1 |
+
from kiui.mesh_utils import clean_mesh
|
|
|
|
|
|
|
2 |
import trimesh
|
3 |
+
import zipfile
|
4 |
+
import tempfile
|
5 |
+
import os
|
6 |
+
import nvdiffrast.torch as dr
|
7 |
|
8 |
def generate3d(model, rgb, ccm, device):
|
9 |
+
model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
|
10 |
+
scale=model.input.scale, geo_type=model.geo_type)
|
|
|
|
|
|
|
|
|
11 |
|
12 |
color_tri = torch.from_numpy(rgb) / 255
|
13 |
xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
|
|
|
14 |
color = color_tri.permute(2, 0, 1)
|
15 |
xyz = xyz_tri.permute(2, 0, 1)
|
16 |
|
17 |
+
def get_imgs(color):
|
18 |
+
return torch.stack([color[:, :, 256 * i:256 * (i + 1)] for i in [5, 0, 1, 2, 3, 4]], dim=0)
|
|
|
|
|
|
|
|
|
19 |
|
20 |
triplane_color = get_imgs(color).permute(0, 2, 3, 1).unsqueeze(0).to(device)
|
21 |
color = get_imgs(color)
|
|
|
23 |
|
24 |
color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
|
25 |
xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
|
|
|
26 |
triplane = torch.cat([color, xyz], dim=1).to(device)
|
27 |
|
28 |
model.eval()
|
29 |
+
if model.denoising:
|
30 |
+
tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
|
31 |
+
noise_new = torch.randn_like(triplane) * 0.5 + 0.5
|
32 |
+
triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
|
33 |
+
with torch.no_grad():
|
|
|
34 |
triplane_feature2 = model.unet2(triplane, tnew)
|
35 |
+
else:
|
36 |
+
with torch.no_grad():
|
37 |
triplane_feature2 = model.unet2(triplane)
|
38 |
|
39 |
+
data_config = {
|
40 |
+
'resolution': [1024, 1024],
|
41 |
+
"triview_color": triplane_color.to(device),
|
42 |
+
}
|
43 |
+
|
44 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
45 |
verts, faces = model.decode(data_config, triplane_feature2)
|
46 |
+
data_config['verts'] = verts[0]
|
47 |
+
data_config['faces'] = faces
|
48 |
|
49 |
+
verts, faces = clean_mesh(
|
50 |
+
data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
|
51 |
+
data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
|
|
|
52 |
repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
|
53 |
)
|
54 |
+
data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
|
55 |
+
data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
|
56 |
|
57 |
+
# Create base filename
|
58 |
+
temp_path = tempfile.NamedTemporaryFile(suffix="", delete=False).name
|
59 |
+
obj_base = temp_path # no extension
|
60 |
|
61 |
+
# Export mesh with UV and PNG
|
62 |
+
glctx = dr.RasterizeCudaContext()
|
63 |
+
model.export_mesh_wt_uv(
|
64 |
+
glctx, data_config, obj_base, "", device, res=(1024, 1024), tri_fea_2=triplane_feature2
|
65 |
+
)
|
66 |
|
67 |
+
# Load .obj with texture and export .glb
|
68 |
+
mesh = trimesh.load(obj_base + ".obj", process=False)
|
69 |
+
mesh.export(obj_base + ".glb")
|
70 |
|
71 |
+
return obj_base + ".glb"
|