Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
8a29dd3
·
1 Parent(s): 263eddc

Refactor color scaling and assignment in generate3d function

Browse files

- Added debug prints to monitor the min/max values of orig_colors before and after scaling.
- Commented out the scaling operation to allow for flexibility if orig_colors are already in the [0, 1] range.
- Updated color assignment to use only the nearest neighbor for improved accuracy in color mapping.

Files changed (1) hide show
  1. inference.py +6 -2
inference.py CHANGED
@@ -68,7 +68,10 @@ def generate3d(model, rgb, ccm, device):
68
  with torch.no_grad():
69
  dec_verts = model.decoder(triplane_feature2, orig_verts_tensor)
70
  orig_colors = model.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
71
- orig_colors = (orig_colors * 0.5 + 0.5).clip(0, 1) # scale to [0, 1]
 
 
 
72
  verts, faces = clean_mesh(
73
  orig_verts.astype(np.float32),
74
  data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
@@ -83,7 +86,8 @@ def generate3d(model, rgb, ccm, device):
83
  # For each new vertex, find the nearest old vertex and copy its color
84
  k = 3
85
  dists, idxs = tree.query(verts, k=k)
86
- new_colors = np.mean(orig_colors[idxs], axis=1)
 
87
 
88
  # Create the new mesh with colors
89
  mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_colors=new_colors)
 
68
  with torch.no_grad():
69
  dec_verts = model.decoder(triplane_feature2, orig_verts_tensor)
70
  orig_colors = model.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
71
+ print('orig_colors min/max BEFORE scaling:', orig_colors.min(), orig_colors.max())
72
+ # Comment out the scaling below if orig_colors is already in [0, 1]
73
+ # orig_colors = (orig_colors * 0.5 + 0.5).clip(0, 1) # scale to [0, 1]
74
+ print('orig_colors min/max AFTER scaling:', orig_colors.min(), orig_colors.max())
75
  verts, faces = clean_mesh(
76
  orig_verts.astype(np.float32),
77
  data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
 
86
  # For each new vertex, find the nearest old vertex and copy its color
87
  k = 3
88
  dists, idxs = tree.query(verts, k=k)
89
+ # Use only the nearest neighbor for color assignment
90
+ new_colors = orig_colors[idxs[:, 0]]
91
 
92
  # Create the new mesh with colors
93
  mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_colors=new_colors)