Colin Leong commited on
Commit
6e290b7
·
1 Parent(s): 4d840a0

Streamlit does not work intuitively

Browse files
Files changed (1) hide show
  1. app.py +22 -146
app.py CHANGED
@@ -9,7 +9,6 @@ from pyzstd import decompress
9
  from PIL import Image
10
  import cv2
11
  import mediapipe as mp
12
- import torch
13
 
14
  mp_holistic = mp.solutions.holistic
15
  FACEMESH_CONTOURS_POINTS = [
@@ -54,35 +53,6 @@ def pose_hide_legs(pose):
54
  raise ValueError("Unknown pose header schema for hiding legs")
55
 
56
 
57
- def preprocess_pose(pose):
58
- pose = pose.get_components(
59
- [
60
- "POSE_LANDMARKS",
61
- "FACE_LANDMARKS",
62
- "LEFT_HAND_LANDMARKS",
63
- "RIGHT_HAND_LANDMARKS",
64
- ],
65
- {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS},
66
- )
67
-
68
- pose = pose.normalize(pose_normalization_info(pose.header))
69
- pose = pose_hide_legs(pose)
70
-
71
- # from sign_vq.data.normalize import pre_process_mediapipe, normalize_mean_std
72
- # from pose_anonymization.appearance import remove_appearance
73
-
74
- # pose = remove_appearance(pose)
75
- # pose = pre_process_mediapipe(pose)
76
- # pose = normalize_mean_std(pose)
77
-
78
- feat = np.nan_to_num(pose.body.data)
79
- feat = feat.reshape(feat.shape[0], -1)
80
-
81
- pose_frames = torch.from_numpy(np.expand_dims(feat, axis=0)).float()
82
-
83
- return pose_frames
84
-
85
-
86
  # @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
87
  def load_pose(uploaded_file: UploadedFile) -> Pose:
88
 
@@ -129,147 +99,53 @@ if uploaded_file is not None:
129
  with st.spinner(f"Loading {uploaded_file.name}"):
130
  pose = load_pose(uploaded_file)
131
  frames, images = get_pose_frames(pose=pose)
132
- st.success("done loading!")
133
- # st.write(f"pose shape: {pose.body.data.shape}")
134
-
135
- header = pose.header
136
  st.write("### File Info")
137
  with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
138
-
139
- st.write(header)
140
- with st.expander(f"Show body information from {uploaded_file.name}"):
141
- st.write(pose.body)
142
- # st.write(pose.body.data.shape)
143
- # st.write(pose.body.fps)
144
 
145
  st.write(f"### Selection")
146
-
147
- components = pose.header.components
148
-
149
- component_names = [component.name for component in components]
150
- chosen_component_names = component_names
151
-
152
  component_selection = st.radio(
153
  "How to select components?", options=["manual", "signclip"]
154
  )
 
155
  if component_selection == "manual":
156
- st.write(f"### Component selection: ")
157
- chosen_component_names = st.pills(
158
- "Components to visualize",
159
- options=component_names,
160
- selection_mode="multi",
161
- default=component_names,
162
  )
163
-
164
- # st.write(chosen_component_names)
165
-
166
- st.write("### Point selection:")
167
- point_names = []
168
- new_chosen_components = []
169
- points_dict = {}
170
- for component in pose.header.components:
171
- with st.expander(f"points for {component.name}"):
172
-
173
- if component.name in chosen_component_names:
174
-
175
- st.write(f"#### {component.name}")
176
- selected_points = st.multiselect(
177
- f"points for component {component.name}:",
178
- options=component.points,
179
- default=component.points,
180
- )
181
- if selected_points == component.points:
182
- st.write(
183
- f"All selected, no need to add a points dict entry for {component.name}"
184
- )
185
- else:
186
- st.write(f"Adding dictionary for {component.name}")
187
- points_dict[component.name] = selected_points
188
-
189
- # selected_points = st.multiselect("points to visualize", options=point_names, default=point_names)
190
  if chosen_component_names:
191
-
192
- if not points_dict:
193
- points_dict = None
194
- # else:
195
- # st.write(points_dict)
196
- # st.write(chosen_component_names)
197
-
198
- pose = pose.get_components(chosen_component_names, points=points_dict)
199
- # st.write(pose.header)
200
 
201
  elif component_selection == "signclip":
202
- st.write("Selected landmarks used for SignCLIP. (Face countours only)")
203
  pose = pose.get_components(
204
- [
205
- "POSE_LANDMARKS",
206
- "FACE_LANDMARKS",
207
- "LEFT_HAND_LANDMARKS",
208
- "RIGHT_HAND_LANDMARKS",
209
- ],
210
- {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS},
211
  )
212
-
213
- # pose = pose.normalize(pose_normalization_info(pose.header)) Visualization goes blank
214
  pose = pose_hide_legs(pose)
215
- with st.expander("Show facemesh contour points:"):
216
- st.write(f"{FACEMESH_CONTOURS_POINTS}")
217
- with st.expander(f"Show header:"):
218
- st.write(pose.header)
219
- # st.write(f"signclip selected, new header:")
220
- # st.write(pose.body.data.shape)
221
- # st.write(pose.header)
222
- else:
223
- pass
224
-
225
- filtered = st.button(f"Filter Components/Points")
226
- if filtered:
227
-
228
-
229
 
 
 
230
  st.write("### Filtered .pose file")
231
  with st.expander("Show header"):
232
- st.write(pose.header)
233
-
234
- with st.expander(f"Show body"):
235
  st.write(pose.body)
236
-
237
  pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
238
- with pose_file_out.open("wb") as f:
239
  pose.write(f)
240
 
241
- with pose_file_out.open("rb") as f:
242
  st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
243
 
244
- st.write(f"### Visualization")
245
- width = st.select_slider(
246
- "select width of images",
247
- list(range(1, pose.header.dimensions.width + 1)),
248
- value=pose.header.dimensions.width / 2,
249
- )
250
- step = st.select_slider(
251
- "Step value to select every nth image", list(range(1, len(frames))), value=1
252
- )
253
- fps = st.slider(
254
- "fps for visualization: ",
255
- min_value=1.0,
256
- max_value=pose.body.fps,
257
- value=pose.body.fps,
258
- )
259
-
260
-
261
-
262
-
263
- visualize_clicked = st.button(f"Visualize!")
264
-
265
-
266
- if visualize_clicked:
267
-
268
- st.write(f"Generating gif...")
269
-
270
- # st.write(pose.body.data.shape)
271
 
272
- st.image(get_pose_gif(pose=pose, step=step, fps=fps))
273
 
274
 
275
 
 
9
  from PIL import Image
10
  import cv2
11
  import mediapipe as mp
 
12
 
13
  mp_holistic = mp.solutions.holistic
14
  FACEMESH_CONTOURS_POINTS = [
 
53
  raise ValueError("Unknown pose header schema for hiding legs")
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # @st.cache_data(hash_funcs={UploadedFile: lambda p: str(p.name)})
57
  def load_pose(uploaded_file: UploadedFile) -> Pose:
58
 
 
99
  with st.spinner(f"Loading {uploaded_file.name}"):
100
  pose = load_pose(uploaded_file)
101
  frames, images = get_pose_frames(pose=pose)
102
+ st.success("Done loading!")
103
+
 
 
104
  st.write("### File Info")
105
  with st.expander(f"Show full Pose-format header from {uploaded_file.name}"):
106
+ st.write(pose.header)
 
 
 
 
 
107
 
108
  st.write(f"### Selection")
 
 
 
 
 
 
109
  component_selection = st.radio(
110
  "How to select components?", options=["manual", "signclip"]
111
  )
112
+
113
  if component_selection == "manual":
114
+ chosen_component_names = st.multiselect(
115
+ "Select components to visualize", options=[c.name for c in pose.header.components]
 
 
 
 
116
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  if chosen_component_names:
118
+ pose = pose.get_components(chosen_component_names)
 
 
 
 
 
 
 
 
119
 
120
  elif component_selection == "signclip":
121
+ st.write("Selected landmarks used for SignCLIP.")
122
  pose = pose.get_components(
123
+ ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
 
 
 
 
 
 
124
  )
 
 
125
  pose = pose_hide_legs(pose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ # Filter button logic
128
+ if st.button("Filter Components/Points"):
129
  st.write("### Filtered .pose file")
130
  with st.expander("Show header"):
131
+ st.write(pose.header)
132
+ with st.expander("Show body"):
 
133
  st.write(pose.body)
134
+
135
  pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
136
+ with pose_file_out.open("wb") as f:
137
  pose.write(f)
138
 
139
+ with pose_file_out.open("rb") as f:
140
  st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
141
 
142
+ # Visualization button logic
143
+ if st.button("Visualize"):
144
+ st.write("### Visualization")
145
+ step = st.select_slider("Step value to select every nth image", list(range(1, len(frames))), value=1)
146
+ fps = st.slider("FPS for visualization", min_value=1.0, max_value=pose.body.fps, value=pose.body.fps)
147
+ st.image(get_pose_gif(pose=pose, step=step, fps=fps))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
149
 
150
 
151