Colin Leong committed on
Commit 84dfc7c · 1 Parent(s): 869eec5

Add YouTube-ASL filtering, and ability to download points_dict and components list

Files changed (1)
  1. app.py +180 -56
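
The hunks below add a download_json helper for the new points_dict / components downloads, but the Streamlit buttons that expose those JSON files are not visible in the shown hunks. As a rough, hypothetical sketch (labels, file names, and example values are assumptions, not part of the commit), the wiring with st.download_button could look like this:

    import json
    import streamlit as st

    def download_json(data) -> bytes:
        # Serialize the selection to UTF-8 JSON bytes, mirroring the helper added in app.py
        return json.dumps(data).encode("utf-8")

    # Hypothetical example values standing in for the app's current selection
    chosen_component_names = ["LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
    points_dict = {"FACE_LANDMARKS": ["0", "4", "13"]}

    st.download_button(
        "Download components list (JSON)",
        data=download_json(chosen_component_names),
        file_name="components.json",
        mime="application/json",
    )
    st.download_button(
        "Download points_dict (JSON)",
        data=download_json(points_dict),
        file_name="points_dict.json",
        mime="application/json",
    )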
app.py CHANGED
@@ -1,9 +1,13 @@
 import streamlit as st
 from streamlit.runtime.uploaded_file_manager import UploadedFile
 import numpy as np
 from pose_format import Pose
 from pose_format.pose_visualizer import PoseVisualizer
-from pathlib import Path
 from pyzstd import decompress
 from PIL import Image
 import mediapipe as mp
@@ -15,39 +19,47 @@ FACEMESH_CONTOURS_POINTS = [
         set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup])
     )
 ]

-def pose_normalization_info(pose_header):
-    if pose_header.components[0].name == "POSE_LANDMARKS":
-        return pose_header.normalization_info(
-            p1=("POSE_LANDMARKS", "RIGHT_SHOULDER"),
-            p2=("POSE_LANDMARKS", "LEFT_SHOULDER"),
-        )

-    if pose_header.components[0].name == "BODY_135":
-        return pose_header.normalization_info(
-            p1=("BODY_135", "RShoulder"), p2=("BODY_135", "LShoulder")
         )

-    if pose_header.components[0].name == "pose_keypoints_2d":
-        return pose_header.normalization_info(
-            p1=("pose_keypoints_2d", "RShoulder"), p2=("pose_keypoints_2d", "LShoulder")
-        )


-def pose_hide_legs(pose):
-    if pose.header.components[0].name == "POSE_LANDMARKS":
-        point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"]
-        # pylint: disable=protected-access
-        points = [
-            pose.header._get_point_index("POSE_LANDMARKS", side + "_" + n)
-            for n in point_names
-            for side in ["LEFT", "RIGHT"]
-        ]
-        pose.body.confidence[:, :, points] = 0
-        pose.body.data[:, :, points, :] = 0
-        return pose
-    else:
-        raise ValueError("Unknown pose header schema for hiding legs")


 # @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
@@ -60,7 +72,7 @@ def load_pose(uploaded_file: UploadedFile) -> Pose:
     return Pose.read(uploaded_file.read())


-@st.cache_data(hash_funcs={Pose: lambda p: np.array(p.body.data)})
 def get_pose_frames(pose: Pose, transparency: bool = False):
     v = PoseVisualizer(pose)
     frames = [frame_data for frame_data in v.draw()]
@@ -73,7 +85,13 @@ def get_pose_frames(pose: Pose, transparency: bool = False):
     return frames, images


-def get_pose_gif(pose: Pose, step: int = 1, start_frame:int=None, end_frame:int=None, fps: int = None):
     if fps is not None:
         pose.body.fps = fps
     v = PoseVisualizer(pose)
@@ -89,37 +107,42 @@ st.write(
 st.write(
     "I made this app to help me visualize and understand the format, including different 'components' and 'points', and what they are named."
 )
-st.write("If you need a .pose file, here's one of [me doing a self-introduction](https://drive.google.com/file/d/1_L5sYVhONDBABuTmQUvjsl94LbFqzEyP/view?usp=sharing), and one of [me signing ASL 'HOUSE'](https://drive.google.com/file/d/1uggYqLyTA4XdDWaWsS9w5hKaPwW86IF_/view?usp=sharing)")
 uploaded_file = st.file_uploader("Upload a .pose file", type=[".pose", ".pose.zst"])


 if uploaded_file is not None:
     with st.spinner(f"Loading {uploaded_file.name}"):
         pose = load_pose(uploaded_file)
         frames, images = get_pose_frames(pose=pose)
         st.success("Done loading!")
-
     st.write("### File Info")
     with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
         st.write(pose.header)

     st.write(f"### Selection")
     component_selection = st.radio(
-        "How to select components?", options=["manual", "signclip"]
     )

     component_names = [c.name for c in pose.header.components]
     chosen_component_names = []
     points_dict = {}
-    hide_legs = False

     if component_selection == "manual":
-

         chosen_component_names = st.pills(
-            "Select components to visualize", options=component_names, default=component_names,selection_mode="multi"
         )
-
         for component in pose.header.components:
             if component.name in chosen_component_names:
                 with st.expander(f"Points for {component.name}"):
@@ -128,32 +151,118 @@ if uploaded_file is not None:
                         options=component.points,
                         default=component.points,
                     )
-                    if selected_points != component.points:  # Only add entry if not all points are selected
                         points_dict[component.name] = selected_points
-
-

     elif component_selection == "signclip":
         st.write("Selected landmarks used for SignCLIP.")
-        chosen_component_names = ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
         points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
-


     # Filter button logic
-    # Filter section
     st.write("### Filter .pose File")
     filtered = st.button("Apply Filter!")
     if filtered:
-        pose = pose.get_components(chosen_component_names, points=points_dict if points_dict else None)
-        if hide_legs:
-            pose = pose_hide_legs(pose)
         st.session_state.filtered_pose = pose

-    filtered_pose = st.session_state.get('filtered_pose', pose)
     if filtered_pose:
-        filtered_pose = st.session_state.get('filtered_pose', pose)
-        st.write(f"#### Filtered .pose file")
         st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
         with st.expander("Show header"):
             st.write(filtered_pose.header)
@@ -170,12 +279,20 @@ if uploaded_file is not None:
             pose.write(f)

         with pose_file_out.open("rb") as f:
-            st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)

-
     st.write("### Visualization")
-    step = st.select_slider("Step value to select every nth image", list(range(1, len(frames))), value=1)
-    fps = st.slider("FPS for visualization", min_value=1.0, max_value=filtered_pose.body.fps, value=filtered_pose.body.fps)
     start_frame, end_frame = st.slider(
         "Select Frame Range",
         0,
@@ -185,6 +302,13 @@ if uploaded_file is not None:
     # Visualization button logic
     if st.button("Visualize"):
         # Load filtered pose if it exists; otherwise, use the unfiltered pose
-
-
-        st.image(get_pose_gif(pose=filtered_pose, step=step, start_frame=start_frame, end_frame=end_frame, fps=fps))

+from pathlib import Path
+import json
+from typing import Dict, Optional, List, Tuple
+from collections import defaultdict
 import streamlit as st
 from streamlit.runtime.uploaded_file_manager import UploadedFile
 import numpy as np
 from pose_format import Pose
+from pose_format.utils.generic import pose_hide_legs, reduce_holistic
 from pose_format.pose_visualizer import PoseVisualizer
 from pyzstd import decompress
 from PIL import Image
 import mediapipe as mp
 
         set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup])
     )
 ]
+COMPONENT_SELECTION_METHODS = ["manual", "signclip", "youtube-asl", "reduce_holistic"]

+def download_json(data):
+    json_data = json.dumps(data)
+    json_bytes = json_data.encode('utf-8')
+    return json_bytes


+def get_points_dict_and_components_with_index_list(
+    pose: Pose, landmark_indices: List[int], components_to_include: Optional[List[str]]
+) -> Tuple[List[str], Dict[str, List[str]]]:
+    """Used to get components/points if you only have a list of indices,
+    e.g. as listed in a research paper like YouTube-ASL.
+    If you want to also explicitly specify component names, you can.
+    So, for example, to get the two hands and the nose you could do the following:
+    c_names, points_dict = get_points_dict_and_components_with_index_list(pose,
+        landmark_indices=[0],  # index 0 is "NOSE" within the POSE_LANDMARKS component
+        components_to_include=["LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"],
     )

+    Then you can just use get_components:
+    filtered_pose = pose.get_components(c_names, points_dict)

+    """
+    components_to_get = []
+    points_dict = defaultdict(list)

+    for c in pose.header.components:
+        for point_name in c.points:
+            point_index = pose.header.get_point_index(c.name, point_name)
+            if point_index in landmark_indices:
+                components_to_get.append(c.name)
+                points_dict[c.name].append(point_name)
+                # print(f"Point with index {point_index} has name {c.name}:{point_name}")
+
+    if components_to_include:
+        components_to_get.extend(components_to_include)
+    components_to_get = list(set(components_to_get))
+    # print("*********************")
+    # print(components_to_get)
+    # print(points_dict)
+    return components_to_get, points_dict


 # @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
 
     return Pose.read(uploaded_file.read())


+@st.cache_data(hash_funcs={Pose: lambda p: np.asarray(p.body.data.data)})
 def get_pose_frames(pose: Pose, transparency: bool = False):
     v = PoseVisualizer(pose)
     frames = [frame_data for frame_data in v.draw()]
 
     return frames, images


+def get_pose_gif(
+    pose: Pose,
+    step: int = 1,
+    start_frame: Optional[int] = None,
+    end_frame: Optional[int] = None,
+    fps: Optional[float] = None,
+):
     if fps is not None:
         pose.body.fps = fps
     v = PoseVisualizer(pose)
 
 st.write(
     "I made this app to help me visualize and understand the format, including different 'components' and 'points', and what they are named."
 )
+st.write(
+    "If you need a .pose file, here's one of [me doing a self-introduction](https://drive.google.com/file/d/1_L5sYVhONDBABuTmQUvjsl94LbFqzEyP/view?usp=sharing), and one of [me signing ASL 'HOUSE'](https://drive.google.com/file/d/1uggYqLyTA4XdDWaWsS9w5hKaPwW86IF_/view?usp=sharing)"
+)
 uploaded_file = st.file_uploader("Upload a .pose file", type=[".pose", ".pose.zst"])


 if uploaded_file is not None:
     with st.spinner(f"Loading {uploaded_file.name}"):
         pose = load_pose(uploaded_file)
+        # st.write(pose.body.data.shape)
         frames, images = get_pose_frames(pose=pose)
         st.success("Done loading!")
+
     st.write("### File Info")
     with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
         st.write(pose.header)

     st.write(f"### Selection")
     component_selection = st.radio(
+        "How to select components?", options=COMPONENT_SELECTION_METHODS
     )

     component_names = [c.name for c in pose.header.components]
     chosen_component_names = []
     points_dict = {}
+    HIDE_LEGS = False

     if component_selection == "manual":

         chosen_component_names = st.pills(
+            "Select components to visualize",
+            options=component_names,
+            default=component_names,
+            selection_mode="multi",
         )
+
         for component in pose.header.components:
             if component.name in chosen_component_names:
                 with st.expander(f"Points for {component.name}"):
 
                         options=component.points,
                         default=component.points,
                     )
+                    if (
+                        selected_points != component.points
+                    ):  # Only add entry if not all points are selected
                         points_dict[component.name] = selected_points

     elif component_selection == "signclip":
         st.write("Selected landmarks used for SignCLIP.")
+        chosen_component_names = [
+            "POSE_LANDMARKS",
+            "FACE_LANDMARKS",
+            "LEFT_HAND_LANDMARKS",
+            "RIGHT_HAND_LANDMARKS",
+        ]
         points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}

+    elif component_selection == "youtube-asl":
+        st.write("Selected landmarks used for YouTube-ASL.")
+        # https://arxiv.org/pdf/2306.15162
+        # For each hand, we use all 21 landmark points.
+        # Colin: So that's all of LEFT_HAND_LANDMARKS and RIGHT_HAND_LANDMARKS.
+        # For the pose, we use 6 landmark points, for the shoulders, elbows and hips.
+        # These are indices 11, 12, 13, 14, 23, 24.
+        # For the face, we use 37 landmark points, from the eyes, eyebrows, lips, and face outline.
+        # These are indices 0, 4, 13, 14, 17, 33, 37, 39, 46, 52, 55, 61, 64, 81, 82, 93, 133, 151, 152, 159, 172, 178,
+        # 181, 263, 269, 276, 282, 285, 291, 294, 311, 323, 362, 386, 397, 468, 473.
+        # Colin: note that these are with refine_face_landmarks on, and are relative to the component itself. Working it all out, the result is:
+        chosen_component_names = ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
+        points_dict = {
+            "POSE_LANDMARKS": [
+                "LEFT_SHOULDER",
+                "RIGHT_SHOULDER",
+                "LEFT_HIP",
+                "RIGHT_HIP",
+                "LEFT_ELBOW",
+                "RIGHT_ELBOW",
+            ],
+            "FACE_LANDMARKS": [
+                "0", "4", "13", "14", "17", "33", "37", "39", "46", "52",
+                "55", "61", "64", "81", "82", "93", "133", "151", "152", "159",
+                "172", "178", "181", "263", "269", "276", "282", "285", "291", "294",
+                "311", "323", "362", "386", "397",
+                "468",  # 468 only exists with the refine_face_landmarks option on MediaPipe
+                "473",  # 473 only exists with the refine_face_landmarks option on MediaPipe
+            ],
+        }

     # Filter button logic
+    # Filter section
     st.write("### Filter .pose File")
     filtered = st.button("Apply Filter!")
     if filtered:
+        st.write(f"Filtering strategy: {component_selection}")
+
+        if component_selection == "reduce_holistic":
+            # st.write(f"reduce_holistic:")
+            pose = reduce_holistic(pose)
+            st.write("Used pose_format.reduce_holistic")
+        else:
+            pose = pose.get_components(
+                components=chosen_component_names, points=points_dict if points_dict else None
+            )
+            with st.expander("Show component list and points dict used for get_components"):
+                st.write("##### Component names")
+                st.write(chosen_component_names)
+                st.write("##### Points dict")
+                st.write(points_dict)
+
+        with st.expander("How to replicate in pose-format"):
+            st.write("##### Usage:")
+            st.write("How to achieve the same result with the pose-format library")
+            # points_dict_str = json.dumps(points_dict, indent=4)
+            usage_string = f"components={chosen_component_names}\npoints_dict={points_dict}\npose = pose.get_components(components=components, points=points_dict)"
+            st.code(usage_string)
+
+        if HIDE_LEGS:
+            pose = pose_hide_legs(pose, remove=True)
         st.session_state.filtered_pose = pose

+    filtered_pose = st.session_state.get("filtered_pose", pose)
     if filtered_pose:
+        filtered_pose = st.session_state.get("filtered_pose", pose)
+        st.write("#### Filtered .pose file")
         st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
         with st.expander("Show header"):
             st.write(filtered_pose.header)
 
             pose.write(f)

         with pose_file_out.open("rb") as f:
+            st.download_button(
+                "Download Filtered Pose", f, file_name=pose_file_out.name
+            )

     st.write("### Visualization")
+    step = st.select_slider(
+        "Step value to select every nth image", list(range(1, len(frames))), value=1
+    )
+    fps = st.slider(
+        "FPS for visualization",
+        min_value=1.0,
+        max_value=filtered_pose.body.fps,
+        value=filtered_pose.body.fps,
+    )
     start_frame, end_frame = st.slider(
         "Select Frame Range",
         0,
 
     # Visualization button logic
     if st.button("Visualize"):
         # Load filtered pose if it exists; otherwise, use the unfiltered pose
+
+        pose_bytes = get_pose_gif(
+            pose=filtered_pose,
+            step=step,
+            start_frame=start_frame,
+            end_frame=end_frame,
+            fps=fps,
+        )
+        if pose_bytes is not None:
+            st.image(pose_bytes)
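
To reuse a downloaded components list and points_dict outside the app, here is a minimal sketch using the same pose-format calls app.py makes above; the file paths and example values are placeholders, not part of the commit:

    import json
    from pose_format import Pose

    # Load a .pose file plus the JSON selections downloaded from the app (placeholder paths)
    with open("example.pose", "rb") as f:
        pose = Pose.read(f.read())
    with open("components.json", "r", encoding="utf-8") as f:
        components = json.load(f)   # e.g. ["POSE_LANDMARKS", "FACE_LANDMARKS", ...]
    with open("points_dict.json", "r", encoding="utf-8") as f:
        points_dict = json.load(f)  # e.g. {"FACE_LANDMARKS": ["0", "4", ...]}

    # The same call the app runs when "Apply Filter!" is pressed
    filtered_pose = pose.get_components(components, points=points_dict if points_dict else None)
    print(filtered_pose.body.data.shape)

    # Write the filtered pose back out, as the app does for its download button
    with open("filtered.pose", "wb") as f:
        filtered_pose.write(f)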