Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
e5c94c9
·
1 Parent(s): b6dcd98

Refactor generate3d function in inference.py to improve mesh export process. Introduced temporary file handling for OBJ, MTL, and PNG outputs, ensuring proper texture mapping. Enhanced code readability by restructuring renderer initialization and file writing logic.

Browse files
Files changed (1) hide show
  1. inference.py +30 -10
inference.py CHANGED
@@ -1,16 +1,20 @@
1
  import numpy as np
2
  import torch
3
  import time
4
- import nvdiffrast.torch as dr
5
  from util.utils import get_tri
6
  import tempfile
7
  from util.renderer import Renderer
8
  import os
 
9
 
10
- def generate3d(model, rgb, ccm, device):
11
 
12
- model.renderer = Renderer(tet_grid_size=model.tet_grid_size, camera_angle_num=model.camera_angle_num,
13
- scale=model.input.scale, geo_type=model.geo_type)
 
 
 
 
 
14
 
15
  color_tri = torch.from_numpy(rgb) / 255
16
  xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
@@ -64,11 +68,27 @@ def generate3d(model, rgb, ccm, device):
64
  data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
65
  data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
66
 
67
- with torch.no_grad():
68
- mesh_path_base = tempfile.NamedTemporaryFile(suffix="", delete=False).name
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- # Export mesh with UV, texture, and MTL
71
- ctx = dr.RasterizeCudaContext(device=device)
72
- model.export_mesh_wt_uv(ctx, data_config, mesh_path_base, ind=0, device=device, res=(1024, 1024), tri_fea_2=triplane_feature2)
 
 
 
73
 
74
- return mesh_path_base + ".obj"
 
1
  import numpy as np
2
  import torch
3
  import time
 
4
  from util.utils import get_tri
5
  import tempfile
6
  from util.renderer import Renderer
7
  import os
8
+ from PIL import Image
9
 
 
10
 
11
+ def generate3d(model, rgb, ccm, device):
12
+ model.renderer = Renderer(
13
+ tet_grid_size=model.tet_grid_size,
14
+ camera_angle_num=model.camera_angle_num,
15
+ scale=model.input.scale,
16
+ geo_type=model.geo_type
17
+ )
18
 
19
  color_tri = torch.from_numpy(rgb) / 255
20
  xyz_tri = torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255
 
68
  data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
69
  data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()
70
 
71
+ # === Export OBJ/MTL/PNG ===
72
+ obj_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
73
+ base_path = obj_path[:-4] # remove .obj
74
+
75
+ texture_path = base_path + ".png"
76
+ mtl_path = base_path + ".mtl"
77
+
78
+ model.export_mesh_geometry(data_config, obj_path) # writes .obj with UVs
79
+ model.export_texture_image(data_config, texture_path) # saves PNG
80
+
81
+ # Write MTL file manually
82
+ with open(mtl_path, "w") as f:
83
+ f.write("newmtl material0\n")
84
+ f.write("Kd 1.000000 1.000000 1.000000\n")
85
+ f.write(f"map_Kd {os.path.basename(texture_path)}\n")
86
 
87
+ # Append .mtl reference to OBJ file
88
+ with open(obj_path, "r") as original:
89
+ lines = original.readlines()
90
+ with open(obj_path, "w") as modified:
91
+ modified.write(f"mtllib {os.path.basename(mtl_path)}\n")
92
+ modified.writelines(lines)
93
 
94
+ return obj_path