Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
f6ac519
·
1 Parent(s): e666402

Refactor generate3d function in inference.py to implement CPU-only UV unwrapping using xatlas and trimesh for improved texture mapping. Updated texture baking logic to utilize the model's decoder and rgbMlp, enhancing the quality of generated textures in mesh exports.

Browse files
Files changed (1) hide show
  1. inference.py +30 -31
inference.py CHANGED
@@ -6,6 +6,7 @@ import zipfile
6
  import nvdiffrast.torch as dr
7
  import xatlas
8
  import cv2
 
9
 
10
  from util.utils import get_tri
11
  from mesh import Mesh
@@ -62,36 +63,34 @@ def generate3d(model, rgb, ccm, device):
62
  data_config['verts'] = torch.from_numpy(verts).contiguous()
63
  data_config['faces'] = torch.from_numpy(faces).contiguous()
64
 
65
- # --- CCM-based UV assignment ---
66
- mesh_v = data_config['verts'].cpu().numpy() # [N, 3]
67
- mesh_f = data_config['faces'].cpu().numpy() # [M, 3]
68
-
69
- # Prepare CCM and color map
70
- ccm_img = ccm.astype(np.uint8) if ccm.max() > 1 else (ccm * 255).astype(np.uint8)
71
- if ccm_img.shape[-1] != 3:
72
- ccm_img = np.transpose(ccm_img, (1, 2, 0))
73
- H, W, _ = ccm_img.shape
74
-
75
- color_map = rgb.astype(np.uint8) if rgb.max() > 1 else (rgb * 255).astype(np.uint8)
76
- if color_map.shape[-1] != 3:
77
- color_map = np.transpose(color_map, (1, 2, 0))
78
- albedo = cv2.cvtColor(color_map, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
79
-
80
- # Project mesh vertices to CCM image space and get UVs
81
- vt = []
82
- for v in mesh_v:
83
- # Assume mesh is in [-1,1] in x/y, project to CCM image
84
- x, y, z = v
85
- u_img = int((x + 1) / 2 * (W - 1))
86
- v_img = int((y + 1) / 2 * (H - 1))
87
- u_img = np.clip(u_img, 0, W-1)
88
- v_img = np.clip(v_img, 0, H-1)
89
- r, g, b = ccm_img[v_img, u_img]
90
- u = r / 255.0
91
- v_ = g / 255.0
92
- vt.append([u, v_])
93
- vt = np.array(vt, dtype=np.float32)
94
- ft = mesh_f.copy()
95
 
96
  # Create Mesh and export .glb
97
  mesh = Mesh(
@@ -99,7 +98,7 @@ def generate3d(model, rgb, ccm, device):
99
  f=torch.from_numpy(mesh_f).int(),
100
  vt=torch.from_numpy(vt).float(),
101
  ft=torch.from_numpy(ft).int(),
102
- albedo=torch.from_numpy(albedo).float()
103
  )
104
  temp_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
105
  mesh.write(temp_path)
 
6
  import nvdiffrast.torch as dr
7
  import xatlas
8
  import cv2
9
+ import trimesh
10
 
11
  from util.utils import get_tri
12
  from mesh import Mesh
 
63
  data_config['verts'] = torch.from_numpy(verts).contiguous()
64
  data_config['faces'] = torch.from_numpy(faces).contiguous()
65
 
66
+ # CPU-only UV unwrapping with xatlas
67
+ mesh_v = data_config['verts'].cpu().numpy()
68
+ mesh_f = data_config['faces'].cpu().numpy()
69
+ vmapping, ft, vt = xatlas.parametrize(mesh_v, mesh_f)
70
+
71
+ # Bake texture using model's decoder and rgbMlp (CPU-only)
72
+ tex_res = (1024, 1024)
73
+ # Generate a grid of UV coordinates
74
+ uv_grid = np.stack(np.meshgrid(
75
+ np.linspace(0, 1, tex_res[0]),
76
+ np.linspace(0, 1, tex_res[1])
77
+ ), -1).reshape(-1, 2) # (H*W, 2)
78
+ # Map UVs to 3D positions using barycentric interpolation of mesh faces
79
+ # For simplicity, we'll sample random points on the mesh surface and use their UVs
80
+ # (A more advanced approach would rasterize each face, but this is a CPU-friendly approximation)
81
+ mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
82
+ points, face_indices = trimesh.sample.sample_surface(mesh_trimesh, tex_res[0]*tex_res[1])
83
+ # Use the model's decoder and rgbMlp to get color for each sampled point
84
+ points_tensor = torch.from_numpy(points).float().unsqueeze(0) # (1, N, 3)
85
+ with torch.no_grad():
86
+ dec_verts = model.decoder(triplane_feature2, points_tensor)
87
+ colors = model.rgbMlp(dec_verts).squeeze().cpu().numpy() # (N, 3)
88
+ colors = (colors * 0.5 + 0.5).clip(0, 1)
89
+ # Fill the texture image
90
+ texture = np.zeros((tex_res[1]*tex_res[0], 3), dtype=np.float32)
91
+ texture[:colors.shape[0]] = colors
92
+ texture = texture.reshape(tex_res[1], tex_res[0], 3)
93
+ texture = np.clip(texture, 0, 1)
 
 
94
 
95
  # Create Mesh and export .glb
96
  mesh = Mesh(
 
98
  f=torch.from_numpy(mesh_f).int(),
99
  vt=torch.from_numpy(vt).float(),
100
  ft=torch.from_numpy(ft).int(),
101
+ albedo=torch.from_numpy(texture).float()
102
  )
103
  temp_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False).name
104
  mesh.write(temp_path)