# Hugging Face Space listing for this file (status at capture time: Runtime error)
# File size: 3,648 bytes
# Blame commits: 06bba7e 53625b9 06bba7e 53625b9 db4fb82 53625b9 f0ae1c5 53625b9 db4fb82 53625b9 db4fb82 02f3e52 db4fb82 02f3e52 4c51593 9b3735c 4c51593 53625b9 06bba7e 0aa5626 06bba7e 4c51593 06bba7e db4fb82 4c51593 06bba7e 1faee00 06bba7e 4c51593 06bba7e 0aa5626 06bba7e
import streamlit as st
from huggingface_hub import HfFolder
# Persist the Hugging Face token stored in Streamlit secrets under 'etoken'.
# This runs before the remaining imports and model loading below; presumably the
# OpenShape weights need an authenticated download -- TODO confirm.
# NOTE(review): HfFolder is deprecated in recent huggingface_hub releases;
# `huggingface_hub.login(token=...)` is the documented replacement -- verify
# against the pinned version before switching.
HfFolder().save_token(st.secrets['etoken'])
import numpy
import trimesh
import objaverse
import openshape
import misc_utils
import plotly.graph_objects as go
@st.cache_resource
def load_openshape(name):
    """Return the OpenShape point-cloud encoder registered under ``name``.

    The ``st.cache_resource`` decorator makes repeated calls with the same
    ``name`` reuse the already-loaded encoder instead of loading it again.
    """
    encoder = openshape.load_pc_encoder(name)
    return encoder
f32 = numpy.float32

# Streamlit re-executes this whole script on every widget interaction, so the
# encoders are obtained through the cached `load_openshape` loader. Calling
# `openshape.load_pc_encoder` directly here reloaded all three models on every
# rerun.
model_b32 = load_openshape('openshape-pointbert-vitb32-rgb')  # used for captioning
# NOTE(review): model_l14 is not referenced anywhere in the visible code;
# consider dropping it to save memory.
model_l14 = load_openshape('openshape-pointbert-vitl14-rgb')
model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')  # used for LVIS classification

st.title("OpenShape Demo")
# Three alternative inputs; load_data() checks them in the order npy, model, objaid.
objaid = st.text_input("Enter an Objaverse ID")
model = st.file_uploader("Or upload a model (.glb/.obj/.ply)")
npy = st.file_uploader("Or upload a point cloud numpy array (.npy of Nx3 XYZ or Nx6 XYZRGB)")
swap_yz_axes = st.checkbox("Swap Y/Z axes of input (Y is up for OpenShape)")
prog = st.progress(0.0, "Idle")
def _normalize_pc(pc, swap_yz):
    """Validate a raw (N, 3) or (N, 6) cloud; centre, rescale and colour-pad it."""
    pc = numpy.asarray(pc)
    if pc.ndim != 2:
        raise ValueError("invalid pc shape: ndim = %d != 2" % pc.ndim)
    if pc.shape[1] not in (3, 6):
        raise ValueError("invalid pc shape: should have 3/6 channels, got %d" % pc.shape[1])
    if len(pc) == 0:
        raise ValueError("invalid pc: contains no points")
    # Work on a floating-point copy. With an integer array the in-place
    # normalisation below would truncate every coordinate to 0. Float inputs
    # keep their precision (float32 stays float32).
    pc = pc.astype(numpy.result_type(pc.dtype, f32))
    if swap_yz:
        # The right-hand fancy index makes a copy, so this is a true swap.
        pc[:, [1, 2]] = pc[:, [2, 1]]
    xyz = pc[:, :3]  # basic slice -> view, so the in-place ops below update pc
    xyz -= xyz.mean(axis=0)
    radius = numpy.linalg.norm(xyz, axis=-1).max()
    # If every point coincides, the radius is 0. Leave the cloud at the origin
    # instead of dividing by zero, which would produce NaN.
    if radius > 0:
        xyz /= radius
    if pc.shape[1] == 3:
        # No colour channels supplied: pad with RGB = 1.
        pc = numpy.concatenate([pc, numpy.ones_like(pc)], axis=-1)
    return pc


def load_data():
    """Build the preprocessed point cloud for whichever input the user supplied.

    Sources are checked in priority order: the uploaded ``.npy`` array
    (``npy``), then the uploaded mesh file (``model``), then the Objaverse ID
    (``objaid``). ``swap_yz_axes`` is honoured, and progress is reported
    through ``prog``. All of these are module-level widget values.

    Returns:
        A ``float32`` array of shape (N, 6). Columns 0-2 hold XYZ centred on
        the mean and scaled so the farthest point has norm 1. Columns 3-5 hold
        the input colours, or ones when the input had only XYZ.

    Raises:
        ValueError: no input was supplied, the array is not (N, 3) or (N, 6),
            or the cloud is empty.
    """
    prog.progress(0.05, "Preparing Point Cloud")
    if npy is not None:
        pc: numpy.ndarray = numpy.load(npy)
    elif model is not None:
        # The upload is a file object, so trimesh cannot infer the format from
        # a path. Pass the (lower-cased) extension explicitly.
        file_type = model.name.split(".")[-1].lower()
        pc = misc_utils.trimesh_to_pc(trimesh.load(model, file_type))
    elif objaid:
        prog.progress(0.1, "Downloading Objaverse Object")
        objamodel = objaverse.load_objects([objaid])[objaid]
        prog.progress(0.2, "Preparing Point Cloud")
        pc = misc_utils.trimesh_to_pc(trimesh.load(objamodel))
    else:
        raise ValueError("You have to supply 3D input!")
    prog.progress(0.25, "Preprocessing Point Cloud")
    pc = _normalize_pc(pc, swap_yz_axes)
    prog.progress(0.3, "Preprocessed Point Cloud")
    return pc.astype(f32)
def render_pc(pc, max_points=2048):
    """Draw a 3D scatter preview of ``pc`` in the left half of a two-column row.

    Args:
        pc: point array whose columns 0-2 are XYZ and 3-5 are RGB, expected in
            [0, 1] (the layout produced by ``load_data``).
        max_points: at most this many randomly chosen points are plotted.

    Returns:
        The right-hand Streamlit column, so callers can place results beside
        the preview.
    """
    keep = numpy.random.permutation(len(pc))[:max_points]
    pc = pc[keep]
    # Clip before the uint8 cast so out-of-range colours (e.g. 0-255 RGB from an
    # uploaded .npy) saturate instead of wrapping around.
    rgb = (numpy.clip(pc[:, 3:], 0.0, 1.0) * 255).astype(numpy.uint8)
    colors = [f'rgb({r}, {g}, {b})' for r, g, b in rgb]
    trace = go.Scatter3d(
        x=pc[:, 0], y=pc[:, 1], z=pc[:, 2],
        mode='markers',
        marker=dict(size=2, color=colors),
    )
    fig = go.Figure(data=[trace])
    # Y is treated as the up axis (see the swap-Y/Z option), so orient the camera to match.
    fig.update_layout(scene_camera=dict(up=dict(x=0, y=1, z=0)))
    col1, col2 = st.columns(2)
    with col1:
        st.plotly_chart(fig, use_container_width=True)
    return col2
try:
    tab_cls, tab_cap = st.tabs(["Classification", "Point Cloud Captioning"])
    with tab_cls:
        if st.button("Run Classification on LVIS Categories"):
            pc = load_data()
            col2 = render_pc(pc)
            prog.progress(0.5, "Running Classification")
            pred = openshape.pred_lvis_sims(model_g14, pc)
            with col2:
                # Show the first five (category, similarity) entries. Presumably
                # pred is ordered by descending similarity -- confirm in
                # openshape.pred_lvis_sims.
                for cat, sim in list(pred.items())[:5]:
                    st.text(cat)
                    st.caption("Similarity %.4f" % sim)
            prog.progress(1.0, "Idle")
    with tab_cap:
        cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0)
        if st.button("Generate a Caption"):
            pc = load_data()
            col2 = render_pc(pc)
            prog.progress(0.5, "Running Generation")
            cap = openshape.pc_caption(model_b32, pc, cond_scale)
            # Place the caption beside the preview, as classification results are.
            with col2:
                st.text(cap)
            prog.progress(1.0, "Idle")
except Exception as exc:
    # Top-level UI boundary: report any failure on the page instead of
    # crashing. Then reset the progress bar so it does not stay stuck mid-run.
    st.error(repr(exc))
    prog.progress(0.0, "Idle")
# (end of file listing)