eliphatfs commited on
Commit
db4fb82
·
1 Parent(s): f981cf0
Files changed (3) hide show
  1. app.py +21 -27
  2. misc_utils.py +5 -4
  3. openshape/classification.py +1 -1
app.py CHANGED
@@ -36,12 +36,12 @@ def load_data():
36
  if npy is not None:
37
  pc: numpy.ndarray = numpy.load(npy)
38
  elif model is not None:
39
- pc = misc_utils.model_to_pc(misc_utils.as_mesh(trimesh.load(model)))
40
  elif objaid:
41
  prog.progress(0.1, "Downloading Objaverse Object")
42
  objamodel = objaverse.load_objects([objaid])[objaid]
43
  prog.progress(0.2, "Preparing Point Cloud")
44
- pc = misc_utils.model_to_pc(misc_utils.as_mesh(trimesh.load(objamodel)))
45
  else:
46
  raise ValueError("You have to supply 3D input!")
47
  prog.progress(0.25, "Preprocessing Point Cloud")
@@ -57,20 +57,17 @@ def load_data():
57
  return pc.astype(f32)
58
 
59
 
60
- def render_pc(ncols, col, pc):
61
  pc = pc[:2048]
62
- cols = st.columns(ncols)
63
- with cols[col]:
64
- rgb = (pc[:, 3:] * 255).astype(numpy.uint8)
65
- g = go.Scatter3d(
66
- x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
67
- mode='markers',
68
- marker=dict(size=2, color=[f'rgb({rgb[i, 0]}, {rgb[i, 1]}, {rgb[i, 2]})' for i in range(len(pc))]),
69
- )
70
- fig = go.Figure(data=[g])
71
- st.plotly_chart(fig)
72
- st.caption("Point Cloud Preview")
73
- return cols
74
 
75
 
76
  try:
@@ -79,13 +76,12 @@ try:
79
  with tab_cls:
80
  if st.button("Run Classification on LVIS Categories"):
81
  pc = load_data()
82
- col1, col2 = render_pc(2, 0, pc)
83
  prog.progress(0.5, "Running Classification")
84
- with col2:
85
- pred = openshape.pred_lvis_sims(model_g14, pc)
86
- for i, (cat, sim) in zip(range(5), pred.items()):
87
- st.text(cat)
88
- st.caption("Similarity %.4f" % sim)
89
  prog.progress(1.0, "Idle")
90
 
91
  with tab_pc2img:
@@ -97,25 +93,23 @@ try:
97
  height = st.slider('Height', 128, 512, step=32)
98
  if st.button("Generate"):
99
  pc = load_data()
100
- col1, col2 = render_pc(2, 0, pc)
101
  prog.progress(0.49, "Running Generation")
102
  img = openshape.pc_to_image(
103
  model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
104
  lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
105
  )
106
- with col2:
107
- st.image(img)
108
  prog.progress(1.0, "Idle")
109
 
110
  with tab_cap:
111
  cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 1.0)
112
  if st.button("Generate a Caption"):
113
  pc = load_data()
114
- col1, col2 = render_pc(2, 0, pc)
115
  prog.progress(0.5, "Running Generation")
116
  cap = openshape.pc_caption(model_b32, pc, cond_scale)
117
- with col2:
118
- st.text(cap)
119
  prog.progress(1.0, "Idle")
120
  except Exception as exc:
121
  st.error(repr(exc))
 
36
  if npy is not None:
37
  pc: numpy.ndarray = numpy.load(npy)
38
  elif model is not None:
39
+ pc = misc_utils.trimesh_to_pc(trimesh.load(model, model.name.split(".")[-1]))
40
  elif objaid:
41
  prog.progress(0.1, "Downloading Objaverse Object")
42
  objamodel = objaverse.load_objects([objaid])[objaid]
43
  prog.progress(0.2, "Preparing Point Cloud")
44
+ pc = misc_utils.trimesh_to_pc(trimesh.load(objamodel))
45
  else:
46
  raise ValueError("You have to supply 3D input!")
47
  prog.progress(0.25, "Preprocessing Point Cloud")
 
57
  return pc.astype(f32)
58
 
59
 
60
+ def render_pc(pc):
61
  pc = pc[:2048]
62
+ rgb = (pc[:, 3:] * 255).astype(numpy.uint8)
63
+ g = go.Scatter3d(
64
+ x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
65
+ mode='markers',
66
+ marker=dict(size=2, color=[f'rgb({rgb[i, 0]}, {rgb[i, 1]}, {rgb[i, 2]})' for i in range(len(pc))]),
67
+ )
68
+ fig = go.Figure(data=[g])
69
+ st.plotly_chart(fig)
70
+ # st.caption("Point Cloud Preview")
 
 
 
71
 
72
 
73
  try:
 
76
  with tab_cls:
77
  if st.button("Run Classification on LVIS Categories"):
78
  pc = load_data()
79
+ render_pc(2, 0, pc)
80
  prog.progress(0.5, "Running Classification")
81
+ pred = openshape.pred_lvis_sims(model_g14, pc)
82
+ for i, (cat, sim) in zip(range(5), pred.items()):
83
+ st.text(cat)
84
+ st.caption("Similarity %.4f" % sim)
 
85
  prog.progress(1.0, "Idle")
86
 
87
  with tab_pc2img:
 
93
  height = st.slider('Height', 128, 512, step=32)
94
  if st.button("Generate"):
95
  pc = load_data()
96
+ render_pc(2, 0, pc)
97
  prog.progress(0.49, "Running Generation")
98
  img = openshape.pc_to_image(
99
  model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
100
  lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
101
  )
102
+ st.image(img)
 
103
  prog.progress(1.0, "Idle")
104
 
105
  with tab_cap:
106
  cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 1.0)
107
  if st.button("Generate a Caption"):
108
  pc = load_data()
109
+ render_pc(2, 0, pc)
110
  prog.progress(0.5, "Running Generation")
111
  cap = openshape.pc_caption(model_b32, pc, cond_scale)
112
+ st.text(cap)
 
113
  prog.progress(1.0, "Idle")
114
  except Exception as exc:
115
  st.error(repr(exc))
misc_utils.py CHANGED
@@ -3,6 +3,7 @@ import trimesh
3
  import trimesh.sample
4
  import trimesh.visual
5
  import trimesh.proximity
 
6
  import matplotlib.pyplot as plotlib
7
 
8
 
@@ -42,16 +43,16 @@ def model_to_pc(mesh: trimesh.Trimesh, n_sample_points=10000):
42
  return numpy.concatenate([numpy.array(pcd, f32), numpy.array(rgba, f32)[:, :3]], axis=-1)
43
 
44
 
45
- def as_mesh(scene_or_mesh):
46
  if isinstance(scene_or_mesh, trimesh.Scene):
47
  meshes = [
48
- trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
49
  for g in scene_or_mesh.geometry.values()
50
  if isinstance(g, trimesh.Trimesh)
51
  ]
52
  if not len(meshes):
53
  return None
54
- return trimesh.util.concatenate(meshes)
55
  else:
56
  assert isinstance(scene_or_mesh, trimesh.Trimesh)
57
- return scene_or_mesh
 
3
  import trimesh.sample
4
  import trimesh.visual
5
  import trimesh.proximity
6
+ import streamlit as st
7
  import matplotlib.pyplot as plotlib
8
 
9
 
 
43
  return numpy.concatenate([numpy.array(pcd, f32), numpy.array(rgba, f32)[:, :3]], axis=-1)
44
 
45
 
46
+ def trimesh_to_pc(scene_or_mesh):
47
  if isinstance(scene_or_mesh, trimesh.Scene):
48
  meshes = [
49
+ model_to_pc(trimesh.Trimesh(vertices=g.vertices, faces=g.faces), 10000 // len(scene_or_mesh.geometry))
50
  for g in scene_or_mesh.geometry.values()
51
  if isinstance(g, trimesh.Trimesh)
52
  ]
53
  if not len(meshes):
54
  return None
55
+ return numpy.concatenate(meshes)
56
  else:
57
  assert isinstance(scene_or_mesh, trimesh.Trimesh)
58
+ return model_to_pc(scene_or_mesh, 10000)
openshape/classification.py CHANGED
@@ -10,4 +10,4 @@ def pred_lvis_sims(pc_encoder: torch.nn.Module, pc):
10
  enc = pc_encoder(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
11
  sim = torch.matmul(F.normalize(lvis.feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
12
  argsort = torch.argsort(sim, descending=True)
13
- return OrderedDict((lvis.categories[i], sim[i]) for i in argsort)
 
10
  enc = pc_encoder(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
11
  sim = torch.matmul(F.normalize(lvis.feats, dim=-1), F.normalize(enc, dim=-1).squeeze())
12
  argsort = torch.argsort(sim, descending=True)
13
+ return OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))