Spaces:
Sleeping
Sleeping
Colin Leong
committed on
Commit
·
84dfc7c
1
Parent(s):
869eec5
Add YouTube-ASL filtering, and ability to download points_dict and components list
Browse files
app.py
CHANGED
@@ -1,9 +1,13 @@
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
3 |
import numpy as np
|
4 |
from pose_format import Pose
|
|
|
5 |
from pose_format.pose_visualizer import PoseVisualizer
|
6 |
-
from pathlib import Path
|
7 |
from pyzstd import decompress
|
8 |
from PIL import Image
|
9 |
import mediapipe as mp
|
@@ -15,39 +19,47 @@ FACEMESH_CONTOURS_POINTS = [
|
|
15 |
set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup])
|
16 |
)
|
17 |
]
|
|
|
18 |
|
19 |
-
def
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
p2=("POSE_LANDMARKS", "LEFT_SHOULDER"),
|
24 |
-
)
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
)
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
p1=("pose_keypoints_2d", "RShoulder"), p2=("pose_keypoints_2d", "LShoulder")
|
34 |
-
)
|
35 |
|
|
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
51 |
|
52 |
|
53 |
# @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
|
@@ -60,7 +72,7 @@ def load_pose(uploaded_file: UploadedFile) -> Pose:
|
|
60 |
return Pose.read(uploaded_file.read())
|
61 |
|
62 |
|
63 |
-
@st.cache_data(hash_funcs={Pose: lambda p: np.
|
64 |
def get_pose_frames(pose: Pose, transparency: bool = False):
|
65 |
v = PoseVisualizer(pose)
|
66 |
frames = [frame_data for frame_data in v.draw()]
|
@@ -73,7 +85,13 @@ def get_pose_frames(pose: Pose, transparency: bool = False):
|
|
73 |
return frames, images
|
74 |
|
75 |
|
76 |
-
def get_pose_gif(
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
if fps is not None:
|
78 |
pose.body.fps = fps
|
79 |
v = PoseVisualizer(pose)
|
@@ -89,37 +107,42 @@ st.write(
|
|
89 |
st.write(
|
90 |
"I made this app to help me visualize and understand the format, including different 'components' and 'points', and what they are named."
|
91 |
)
|
92 |
-
st.write(
|
|
|
|
|
93 |
uploaded_file = st.file_uploader("Upload a .pose file", type=[".pose", ".pose.zst"])
|
94 |
|
95 |
|
96 |
if uploaded_file is not None:
|
97 |
with st.spinner(f"Loading {uploaded_file.name}"):
|
98 |
pose = load_pose(uploaded_file)
|
|
|
99 |
frames, images = get_pose_frames(pose=pose)
|
100 |
st.success("Done loading!")
|
101 |
-
|
102 |
st.write("### File Info")
|
103 |
with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
|
104 |
st.write(pose.header)
|
105 |
|
106 |
st.write(f"### Selection")
|
107 |
component_selection = st.radio(
|
108 |
-
"How to select components?", options=
|
109 |
)
|
110 |
|
111 |
component_names = [c.name for c in pose.header.components]
|
112 |
chosen_component_names = []
|
113 |
points_dict = {}
|
114 |
-
|
115 |
|
116 |
if component_selection == "manual":
|
117 |
-
|
118 |
|
119 |
chosen_component_names = st.pills(
|
120 |
-
"Select components to visualize",
|
|
|
|
|
|
|
121 |
)
|
122 |
-
|
123 |
for component in pose.header.components:
|
124 |
if component.name in chosen_component_names:
|
125 |
with st.expander(f"Points for {component.name}"):
|
@@ -128,32 +151,118 @@ if uploaded_file is not None:
|
|
128 |
options=component.points,
|
129 |
default=component.points,
|
130 |
)
|
131 |
-
if
|
|
|
|
|
132 |
points_dict[component.name] = selected_points
|
133 |
-
|
134 |
-
|
135 |
|
136 |
elif component_selection == "signclip":
|
137 |
st.write("Selected landmarks used for SignCLIP.")
|
138 |
-
chosen_component_names = [
|
|
|
|
|
|
|
|
|
|
|
139 |
points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
|
140 |
-
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
# Filter button logic
|
144 |
-
|
145 |
st.write("### Filter .pose File")
|
146 |
filtered = st.button("Apply Filter!")
|
147 |
if filtered:
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
st.session_state.filtered_pose = pose
|
152 |
|
153 |
-
filtered_pose = st.session_state.get(
|
154 |
if filtered_pose:
|
155 |
-
filtered_pose = st.session_state.get(
|
156 |
-
st.write(
|
157 |
st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
|
158 |
with st.expander("Show header"):
|
159 |
st.write(filtered_pose.header)
|
@@ -170,12 +279,20 @@ if uploaded_file is not None:
|
|
170 |
pose.write(f)
|
171 |
|
172 |
with pose_file_out.open("rb") as f:
|
173 |
-
st.download_button(
|
|
|
|
|
174 |
|
175 |
-
|
176 |
st.write("### Visualization")
|
177 |
-
step = st.select_slider(
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
start_frame, end_frame = st.slider(
|
180 |
"Select Frame Range",
|
181 |
0,
|
@@ -185,6 +302,13 @@ if uploaded_file is not None:
|
|
185 |
# Visualization button logic
|
186 |
if st.button("Visualize"):
|
187 |
# Load filtered pose if it exists; otherwise, use the unfiltered pose
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import json
|
3 |
+
from typing import Dict, Optional, List, Tuple
|
4 |
+
from collections import defaultdict
|
5 |
import streamlit as st
|
6 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
7 |
import numpy as np
|
8 |
from pose_format import Pose
|
9 |
+
from pose_format.utils.generic import pose_hide_legs, reduce_holistic
|
10 |
from pose_format.pose_visualizer import PoseVisualizer
|
|
|
11 |
from pyzstd import decompress
|
12 |
from PIL import Image
|
13 |
import mediapipe as mp
|
|
|
19 |
set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup])
|
20 |
)
|
21 |
]
|
22 |
+
COMPONENT_SELECTION_METHODS = ["manual", "signclip", "youtube-asl", "reduce_holistic"]
|
23 |
|
24 |
+
def download_json(data):
|
25 |
+
json_data = json.dumps(data)
|
26 |
+
json_bytes = json_data.encode('utf-8')
|
27 |
+
return json_bytes
|
|
|
|
|
28 |
|
29 |
+
def get_points_dict_and_components_with_index_list(
|
30 |
+
pose: Pose, landmark_indices: List[int], components_to_include: Optional[List[str]]
|
31 |
+
) -> Tuple[List[str], Dict[str, List[str]]]:
|
32 |
+
"""Used to get components/points if you only have a list of indices,
|
33 |
+
e.g. listed in a research paper like YouTube-ASL.
|
34 |
+
If you want to also explicitly specify component names, you can.
|
35 |
+
So for example, to get the two hands and the nose you could do the following:
|
36 |
+
c_names, points_dict = get_points_dict_and_components_with_index_list(pose,
|
37 |
+
landmark_indices=[0],  # which is "NOSE" within the POSE_LANDMARKS component
|
38 |
+
components_to_include=["LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"],
|
39 |
)
|
40 |
|
41 |
+
then you can just use get_components
|
42 |
+
filtered_pose = pose.get_components(c_names, points_dict)
|
|
|
|
|
43 |
|
44 |
+
"""
|
45 |
+
components_to_get = []
|
46 |
+
points_dict = defaultdict(list)
|
47 |
|
48 |
+
for c in pose.header.components:
|
49 |
+
for point_name in c.points:
|
50 |
+
point_index = pose.header.get_point_index(c.name, point_name)
|
51 |
+
if point_index in landmark_indices:
|
52 |
+
components_to_get.append(c.name)
|
53 |
+
points_dict[c.name].append(point_name)
|
54 |
+
# print(f"Point with index {point_index} has name {c.name}:{point_name}")
|
55 |
+
|
56 |
+
if components_to_include:
|
57 |
+
components_to_get.extend(components_to_include)
|
58 |
+
components_to_get = list(set(components_to_get))
|
59 |
+
# print("*********************")
|
60 |
+
# print(components_to_get)
|
61 |
+
# print(points_dict)
|
62 |
+
return components_to_get, points_dict
|
63 |
|
64 |
|
65 |
# @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
|
|
|
72 |
return Pose.read(uploaded_file.read())
|
73 |
|
74 |
|
75 |
+
@st.cache_data(hash_funcs={Pose: lambda p: np.asarray(p.body.data.data)})
|
76 |
def get_pose_frames(pose: Pose, transparency: bool = False):
|
77 |
v = PoseVisualizer(pose)
|
78 |
frames = [frame_data for frame_data in v.draw()]
|
|
|
85 |
return frames, images
|
86 |
|
87 |
|
88 |
+
def get_pose_gif(
|
89 |
+
pose: Pose,
|
90 |
+
step: int = 1,
|
91 |
+
start_frame: Optional[int] = None,
|
92 |
+
end_frame: Optional[int] = None,
|
93 |
+
fps: Optional[float] = None,
|
94 |
+
):
|
95 |
if fps is not None:
|
96 |
pose.body.fps = fps
|
97 |
v = PoseVisualizer(pose)
|
|
|
107 |
st.write(
|
108 |
"I made this app to help me visualize and understand the format, including different 'components' and 'points', and what they are named."
|
109 |
)
|
110 |
+
st.write(
|
111 |
+
"If you need a .pose file, here's one of [me doing a self-introduction](https://drive.google.com/file/d/1_L5sYVhONDBABuTmQUvjsl94LbFqzEyP/view?usp=sharing), and one of [me signing ASL 'HOUSE'](https://drive.google.com/file/d/1uggYqLyTA4XdDWaWsS9w5hKaPwW86IF_/view?usp=sharing)"
|
112 |
+
)
|
113 |
uploaded_file = st.file_uploader("Upload a .pose file", type=[".pose", ".pose.zst"])
|
114 |
|
115 |
|
116 |
if uploaded_file is not None:
|
117 |
with st.spinner(f"Loading {uploaded_file.name}"):
|
118 |
pose = load_pose(uploaded_file)
|
119 |
+
# st.write(pose.body.data.shape)
|
120 |
frames, images = get_pose_frames(pose=pose)
|
121 |
st.success("Done loading!")
|
122 |
+
|
123 |
st.write("### File Info")
|
124 |
with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
|
125 |
st.write(pose.header)
|
126 |
|
127 |
st.write(f"### Selection")
|
128 |
component_selection = st.radio(
|
129 |
+
"How to select components?", options=COMPONENT_SELECTION_METHODS
|
130 |
)
|
131 |
|
132 |
component_names = [c.name for c in pose.header.components]
|
133 |
chosen_component_names = []
|
134 |
points_dict = {}
|
135 |
+
HIDE_LEGS = False
|
136 |
|
137 |
if component_selection == "manual":
|
|
|
138 |
|
139 |
chosen_component_names = st.pills(
|
140 |
+
"Select components to visualize",
|
141 |
+
options=component_names,
|
142 |
+
default=component_names,
|
143 |
+
selection_mode="multi",
|
144 |
)
|
145 |
+
|
146 |
for component in pose.header.components:
|
147 |
if component.name in chosen_component_names:
|
148 |
with st.expander(f"Points for {component.name}"):
|
|
|
151 |
options=component.points,
|
152 |
default=component.points,
|
153 |
)
|
154 |
+
if (
|
155 |
+
selected_points != component.points
|
156 |
+
): # Only add entry if not all points are selected
|
157 |
points_dict[component.name] = selected_points
|
|
|
|
|
158 |
|
159 |
elif component_selection == "signclip":
|
160 |
st.write("Selected landmarks used for SignCLIP.")
|
161 |
+
chosen_component_names = [
|
162 |
+
"POSE_LANDMARKS",
|
163 |
+
"FACE_LANDMARKS",
|
164 |
+
"LEFT_HAND_LANDMARKS",
|
165 |
+
"RIGHT_HAND_LANDMARKS",
|
166 |
+
]
|
167 |
points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
|
|
|
168 |
|
169 |
+
elif component_selection == "youtube-asl":
|
170 |
+
st.write("Selected landmarks used for SignCLIP.")
|
171 |
+
# https://arxiv.org/pdf/2306.15162
|
172 |
+
# For each hand, we use all 21 landmark points.
|
173 |
+
# Colin: So that's 21 x 2 = 42 hand landmark points in total.
|
174 |
+
# For the pose, we use 6 landmark points, for the shoulders, elbows and hips
|
175 |
+
# These are indices 11, 12, 13, 14, 23, 24
|
176 |
+
# For the face, we use 37 landmark points, from the eyes, eyebrows, lips, and face outline.
|
177 |
+
# These are indices 0, 4, 13, 14, 17, 33, 37, 39, 46, 52, 55, 61, 64, 81, 82, 93, 133, 151, 152, 159, 172, 178,
|
178 |
+
# 181, 263, 269, 276, 282, 285, 291, 294, 311, 323, 362, 386, 397, 468, 473.
|
179 |
+
# Colin: note that these are with refine_face_landmarks on, and are relative to the component itself. Working it all out the result is:
|
180 |
+
components=['POSE_LANDMARKS', 'FACE_LANDMARKS', 'LEFT_HAND_LANDMARKS', 'RIGHT_HAND_LANDMARKS']
|
181 |
+
points_dict={
|
182 |
+
"POSE_LANDMARKS": [
|
183 |
+
"LEFT_SHOULDER",
|
184 |
+
"RIGHT_SHOULDER",
|
185 |
+
"LEFT_HIP",
|
186 |
+
"RIGHT_HIP",
|
187 |
+
"LEFT_ELBOW",
|
188 |
+
"RIGHT_ELBOW"
|
189 |
+
],
|
190 |
+
"FACE_LANDMARKS": [
|
191 |
+
"0",
|
192 |
+
"4",
|
193 |
+
"13",
|
194 |
+
"14",
|
195 |
+
"17",
|
196 |
+
"33",
|
197 |
+
"37",
|
198 |
+
"39",
|
199 |
+
"46",
|
200 |
+
"52",
|
201 |
+
"55",
|
202 |
+
"61",
|
203 |
+
"64",
|
204 |
+
"81",
|
205 |
+
"82",
|
206 |
+
"93",
|
207 |
+
"133",
|
208 |
+
"151",
|
209 |
+
"152",
|
210 |
+
"159",
|
211 |
+
"172",
|
212 |
+
"178",
|
213 |
+
"181",
|
214 |
+
"263",
|
215 |
+
"269",
|
216 |
+
"276",
|
217 |
+
"282",
|
218 |
+
"285",
|
219 |
+
"291",
|
220 |
+
"294",
|
221 |
+
"311",
|
222 |
+
"323",
|
223 |
+
"362",
|
224 |
+
"386",
|
225 |
+
"397",
|
226 |
+
"468", # 468 only exists with the refine_face_landmarks option on MediaPipe
|
227 |
+
"473", # 473 only exists with the refine_face_landmarks option on MediaPipe
|
228 |
+
]
|
229 |
+
}
|
230 |
|
231 |
# Filter button logic
|
232 |
+
# Filter section
|
233 |
st.write("### Filter .pose File")
|
234 |
filtered = st.button("Apply Filter!")
|
235 |
if filtered:
|
236 |
+
st.write(f"Filtering strategy: {component_selection}")
|
237 |
+
|
238 |
+
if component_selection == "reduce_holistic":
|
239 |
+
# st.write(f"reduce_holistic:")
|
240 |
+
pose = reduce_holistic(pose)
|
241 |
+
st.write("Used pose_format.reduce_holistic")
|
242 |
+
else:
|
243 |
+
pose = pose.get_components(components=chosen_component_names, points=points_dict if points_dict else None
|
244 |
+
)
|
245 |
+
with st.expander("Show component list and points dict used for get_components"):
|
246 |
+
st.write("##### Component names")
|
247 |
+
st.write(chosen_component_names)
|
248 |
+
st.write("##### Points dict")
|
249 |
+
st.write(points_dict)
|
250 |
+
|
251 |
+
with st.expander("How to replicate in pose-format"):
|
252 |
+
st.write("##### Usage:")
|
253 |
+
st.write("How to achieve the same result with pose-format library")
|
254 |
+
# points_dict_str = json.dumps(points_dict, indent=4)
|
255 |
+
usage_string = f"components={chosen_component_names}\npoints_dict={points_dict}\npose = pose.get_components(components=components, points=points_dict)"
|
256 |
+
st.code(usage_string)
|
257 |
+
|
258 |
+
if HIDE_LEGS:
|
259 |
+
pose = pose_hide_legs(pose, remove=True)
|
260 |
st.session_state.filtered_pose = pose
|
261 |
|
262 |
+
filtered_pose = st.session_state.get("filtered_pose", pose)
|
263 |
if filtered_pose:
|
264 |
+
filtered_pose = st.session_state.get("filtered_pose", pose)
|
265 |
+
st.write("#### Filtered .pose file")
|
266 |
st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
|
267 |
with st.expander("Show header"):
|
268 |
st.write(filtered_pose.header)
|
|
|
279 |
pose.write(f)
|
280 |
|
281 |
with pose_file_out.open("rb") as f:
|
282 |
+
st.download_button(
|
283 |
+
"Download Filtered Pose", f, file_name=pose_file_out.name
|
284 |
+
)
|
285 |
|
|
|
286 |
st.write("### Visualization")
|
287 |
+
step = st.select_slider(
|
288 |
+
"Step value to select every nth image", list(range(1, len(frames))), value=1
|
289 |
+
)
|
290 |
+
fps = st.slider(
|
291 |
+
"FPS for visualization",
|
292 |
+
min_value=1.0,
|
293 |
+
max_value=filtered_pose.body.fps,
|
294 |
+
value=filtered_pose.body.fps,
|
295 |
+
)
|
296 |
start_frame, end_frame = st.slider(
|
297 |
"Select Frame Range",
|
298 |
0,
|
|
|
302 |
# Visualization button logic
|
303 |
if st.button("Visualize"):
|
304 |
# Load filtered pose if it exists; otherwise, use the unfiltered pose
|
305 |
+
|
306 |
+
pose_bytes = get_pose_gif(
|
307 |
+
pose=filtered_pose,
|
308 |
+
step=step,
|
309 |
+
start_frame=start_frame,
|
310 |
+
end_frame=end_frame,
|
311 |
+
fps=fps,
|
312 |
+
)
|
313 |
+
if pose_bytes is not None:
|
314 |
+
st.image(pose_bytes)
|