Spaces:
mashroo
/
Runtime error

YoussefAnso commited on
Commit
dcb102c
·
1 Parent(s): f6ac519

Ensure tensors are on CPU in generate3d function of inference.py for consistent processing. Updated points_tensor and triplane_feature2 to prevent device-related issues during inference.

Browse files
Files changed (1) hide show
  1. inference.py +2 -1
inference.py CHANGED
@@ -81,7 +81,8 @@ def generate3d(model, rgb, ccm, device):
81
  mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
82
  points, face_indices = trimesh.sample.sample_surface(mesh_trimesh, tex_res[0]*tex_res[1])
83
  # Use the model's decoder and rgbMlp to get color for each sampled point
84
- points_tensor = torch.from_numpy(points).float().unsqueeze(0) # (1, N, 3)
 
85
  with torch.no_grad():
86
  dec_verts = model.decoder(triplane_feature2, points_tensor)
87
  colors = model.rgbMlp(dec_verts).squeeze().cpu().numpy() # (N, 3)
 
81
  mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
82
  points, face_indices = trimesh.sample.sample_surface(mesh_trimesh, tex_res[0]*tex_res[1])
83
  # Use the model's decoder and rgbMlp to get color for each sampled point
84
+ points_tensor = torch.from_numpy(points).float().unsqueeze(0).cpu() # (1, N, 3) on CPU
85
+ triplane_feature2 = triplane_feature2.cpu() # Ensure on CPU
86
  with torch.no_grad():
87
  dec_verts = model.decoder(triplane_feature2, points_tensor)
88
  colors = model.rgbMlp(dec_verts).squeeze().cpu().numpy() # (N, 3)