Commit
·
e5c94c9
1
Parent(s):
b6dcd98
Refactor generate3d function in inference.py to improve mesh export process. Introduced temporary file handling for OBJ, MTL, and PNG outputs, ensuring proper texture mapping. Enhanced code readability by restructuring renderer initialization and file writing logic.
Browse files- inference.py +30 -10
inference.py
CHANGED
@@ -1,16 +1,20 @@
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import time
|
4 |
-
import nvdiffrast.torch as dr
|
5 |
from util.utils import get_tri
|
6 |
import tempfile
|
7 |
from util.renderer import Renderer
|
8 |
import os
|
|
|
9 |
|
10 |
-
def generate3d(model, rgb, ccm, device):
|
11 |
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
color_tri = torch.from_numpy(rgb) / 255
|
16 |
xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
|
@@ -64,11 +68,27 @@ def generate3d(model, rgb, ccm, device):
|
|
64 |
data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
|
65 |
data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
|
66 |
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
73 |
|
74 |
-
return
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import time
|
|
|
4 |
from util.utils import get_tri
|
5 |
import tempfile
|
6 |
from util.renderer import Renderer
|
7 |
import os
|
8 |
+
from PIL import Image
|
9 |
|
|
|
10 |
|
11 |
+
def generate3d(model, rgb, ccm, device):
|
12 |
+
model.renderer = Renderer(
|
13 |
+
tet_grid_size=model.tet_grid_size,
|
14 |
+
camera_angle_num=model.camera_angle_num,
|
15 |
+
scale=model.input.scale,
|
16 |
+
geo_type=model.geo_type
|
17 |
+
)
|
18 |
|
19 |
color_tri = torch.from_numpy(rgb) / 255
|
20 |
xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
|
|
|
68 |
data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
|
69 |
data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
|
70 |
|
71 |
+
# === Export OBJ/MTL/PNG ===
|
72 |
+
obj_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
|
73 |
+
base_path = obj_path[:-4] # remove .obj
|
74 |
+
|
75 |
+
texture_path = base_path + ".png"
|
76 |
+
mtl_path = base_path + ".mtl"
|
77 |
+
|
78 |
+
model.export_mesh_geometry(data_config, obj_path) # writes .obj with UVs
|
79 |
+
model.export_texture_image(data_config, texture_path) # saves PNG
|
80 |
+
|
81 |
+
# Write MTL file manually
|
82 |
+
with open(mtl_path, "w") as f:
|
83 |
+
f.write("newmtl material0\n")
|
84 |
+
f.write("Kd 1.000000 1.000000 1.000000\n")
|
85 |
+
f.write(f"map_Kd {os.path.basename(texture_path)}\n")
|
86 |
|
87 |
+
# Append .mtl reference to OBJ file
|
88 |
+
with open(obj_path, "r") as original:
|
89 |
+
lines = original.readlines()
|
90 |
+
with open(obj_path, "w") as modified:
|
91 |
+
modified.write(f"mtllib {os.path.basename(mtl_path)}\n")
|
92 |
+
modified.writelines(lines)
|
93 |
|
94 |
+
return obj_path
|