Commit
·
2cf3ed2
1
Parent(s):
30b6055
Refactor generate3d function in inference.py to streamline noise addition and improve mesh processing. Enhanced color generation for vertices and updated GLB export logic, ensuring efficient handling of mesh data and improved readability.
Browse files- inference.py +19 -47
inference.py
CHANGED
@@ -2,12 +2,9 @@ import numpy as np
|
|
2 |
import torch
|
3 |
import time
|
4 |
import tempfile
|
5 |
-
import os
|
6 |
-
from PIL import Image
|
7 |
import trimesh
|
8 |
|
9 |
from util.utils import get_tri
|
10 |
-
from mesh import Mesh
|
11 |
from util.renderer import Renderer
|
12 |
|
13 |
def generate3d(model, rgb, ccm, device):
|
@@ -42,61 +39,36 @@ def generate3d(model, rgb, ccm, device):
|
|
42 |
|
43 |
model.eval()
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
triplane_feature2 = model.unet2(triplane, tnew)
|
51 |
-
|
52 |
-
with torch.no_grad():
|
53 |
triplane_feature2 = model.unet2(triplane)
|
54 |
|
55 |
-
data_config = {
|
56 |
-
'resolution': [1024, 1024],
|
57 |
-
"triview_color": triplane_color.to(device),
|
58 |
-
}
|
59 |
-
|
60 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
61 |
verts, faces = model.decode(data_config, triplane_feature2)
|
62 |
-
data_config['verts'] = verts[0]
|
63 |
-
data_config['faces'] = faces
|
64 |
|
65 |
from kiui.mesh_utils import clean_mesh
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
|
70 |
)
|
71 |
-
data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
|
72 |
-
data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
|
73 |
-
|
74 |
-
# Export to .obj with UV and texture image
|
75 |
-
mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
|
76 |
-
model.export_mesh_wt_uv(
|
77 |
-
None, # GL context skipped (no dr.RasterizeGLContext)
|
78 |
-
data_config,
|
79 |
-
mesh_path_obj,
|
80 |
-
"",
|
81 |
-
device,
|
82 |
-
res=(1024, 1024),
|
83 |
-
tri_fea_2=triplane_feature2
|
84 |
-
)
|
85 |
-
|
86 |
-
# Convert to GLB with embedded texture
|
87 |
-
obj_path = mesh_path_obj + ".obj"
|
88 |
-
texture_path = mesh_path_obj + ".png"
|
89 |
-
|
90 |
-
if not os.path.exists(obj_path) or not os.path.exists(texture_path):
|
91 |
-
raise RuntimeError("OBJ or texture file missing — cannot convert to GLB.")
|
92 |
|
93 |
-
|
94 |
-
|
|
|
95 |
|
96 |
-
|
97 |
-
geom.visual = trimesh.visual.texture.TextureVisuals(image=texture_image)
|
98 |
|
99 |
-
glb_path =
|
100 |
-
|
101 |
|
102 |
return glb_path
|
|
|
2 |
import torch
|
3 |
import time
|
4 |
import tempfile
|
|
|
|
|
5 |
import trimesh
|
6 |
|
7 |
from util.utils import get_tri
|
|
|
8 |
from util.renderer import Renderer
|
9 |
|
10 |
def generate3d(model, rgb, ccm, device):
|
|
|
39 |
|
40 |
model.eval()
|
41 |
|
42 |
+
with torch.no_grad():
|
43 |
+
if model.denoising:
|
44 |
+
tnew = torch.randint(20, 21, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
|
45 |
+
noise_new = torch.randn_like(triplane) * 0.5 + 0.5
|
46 |
+
triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
|
47 |
triplane_feature2 = model.unet2(triplane, tnew)
|
48 |
+
else:
|
|
|
49 |
triplane_feature2 = model.unet2(triplane)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
with torch.no_grad():
|
52 |
+
data_config = {
|
53 |
+
'resolution': [1024, 1024],
|
54 |
+
"triview_color": triplane_color.to(device),
|
55 |
+
}
|
56 |
verts, faces = model.decode(data_config, triplane_feature2)
|
|
|
|
|
57 |
|
58 |
from kiui.mesh_utils import clean_mesh
|
59 |
+
verts_np, faces_np = clean_mesh(
|
60 |
+
verts[0].squeeze().cpu().numpy().astype(np.float32),
|
61 |
+
faces.squeeze().cpu().numpy().astype(np.int32),
|
62 |
repair=False, remesh=True, remesh_size=0.005, remesh_iters=1
|
63 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
# === Generate per-vertex color (approximate)
|
66 |
+
colors = np.tile(np.mean(rgb.reshape(-1, 3), axis=0, keepdims=True), (verts_np.shape[0], 1)) / 255.0
|
67 |
+
# Optionally, use more sophisticated color mapping logic if you have UVs
|
68 |
|
69 |
+
mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np, vertex_colors=colors, process=False)
|
|
|
70 |
|
71 |
+
glb_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
|
72 |
+
mesh.export(glb_path)
|
73 |
|
74 |
return glb_path
|