Commit
·
dcb102c
1
Parent(s):
f6ac519
Ensure tensors are on CPU in generate3d function of inference.py for consistent processing. Updated points_tensor and triplane_feature2 to prevent device-related issues during inference.
Browse files- inference.py +2 -1
inference.py
CHANGED
@@ -81,7 +81,8 @@ def generate3d(model, rgb, ccm, device):
|
|
81 |
mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
|
82 |
points, face_indices = trimesh.sample.sample_surface(mesh_trimesh, tex_res[0]*tex_res[1])
|
83 |
# Use the model's decoder and rgbMlp to get color for each sampled point
|
84 |
-
points_tensor = torch.from_numpy(points).float().unsqueeze(0) # (1, N, 3)
|
|
|
85 |
with torch.no_grad():
|
86 |
dec_verts = model.decoder(triplane_feature2, points_tensor)
|
87 |
colors = model.rgbMlp(dec_verts).squeeze().cpu().numpy() # (N, 3)
|
|
|
81 |
mesh_trimesh = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f, process=False)
|
82 |
points, face_indices = trimesh.sample.sample_surface(mesh_trimesh, tex_res[0]*tex_res[1])
|
83 |
# Use the model's decoder and rgbMlp to get color for each sampled point
|
84 |
+
points_tensor = torch.from_numpy(points).float().unsqueeze(0).cpu() # (1, N, 3) on CPU
|
85 |
+
triplane_feature2 = triplane_feature2.cpu() # Ensure on CPU
|
86 |
with torch.no_grad():
|
87 |
dec_verts = model.decoder(triplane_feature2, points_tensor)
|
88 |
colors = model.rgbMlp(dec_verts).squeeze().cpu().numpy() # (N, 3)
|