app.py
CHANGED
@@ -68,7 +68,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
         logits_gender = torch.softmax(self.gender(hidden_states), dim=1)

         return hidden_states, logits_age, logits_gender
-
+
 # AgeGenderModel.forward() is switched to accept computed frozen CNN7 features from ExpressioNmodel

 def _forward(
@@ -178,7 +178,7 @@ age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
 expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)

 def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
-
+
     # batch audio
     y = expression_processor(x, sampling_rate=sampling_rate)
     y = y['input_values'][0]
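Both hunks above rely on the same device: a replacement forward function is bound onto an existing module instance with types.MethodType, which, going by the comment in the first hunk, is how the age/gender model comes to accept the frozen CNN7 features already computed by the expression model instead of re-running its own feature extractor. A minimal sketch of that rebinding pattern, with the toy names Extractor and _forward_cached standing in for the real wav2vec2 internals:

import types

class Extractor:
    def forward(self, x):
        # original, self-contained feature computation
        return [v * 2 for v in x]

def _forward_cached(self, x, precomputed=None):
    # replacement forward: reuse features if a sibling model already computed them
    if precomputed is not None:
        return precomputed
    return [v * 2 for v in x]

extractor = Extractor()
# rebind the instance's forward, as app.py does for wav2vec2.forward
extractor.forward = types.MethodType(_forward_cached, extractor)

print(extractor.forward([1, 2, 3]))                   # [2, 4, 6], computed normally
print(extractor.forward([1, 2, 3], precomputed=[9]))  # [9], recomputation skipped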
@@ -227,7 +227,7 @@ def recognize(input_file: str) -> typing.Tuple[str, dict, str]:
     return process_func(signal, target_rate)


-def plot_expression(arousal, dominance, valence):
+def plot_expression_RIGID(arousal, dominance, valence):
     r"""3D pixel plot of arousal, dominance, valence."""
     # Voxels per dimension
     voxels = 7
@@ -271,6 +271,105 @@ plot_expression(arousal, dominance, valence):
         verticalalignment="top",
     )

+COLORMAP = plt.get_cmap('coolwarm')
+N_PIX = 5
+
+matplotlib.rcParams['mathtext.fontset'] = 'stix'
+matplotlib.rcParams['font.family'] = 'STIXGeneral'
+
+def explode(data):
+    '''replicate 16 x 16 x 16 cube to edges array 31 x 31 x 31'''
+    size = np.array(data.shape)*2
+    data_e = np.zeros(size - 1, dtype=data.dtype)
+    data_e[::2, ::2, ::2] = data
+    return data_e
+
+
+def plot_expression(arousal, dominance, valence):
+
+    '''_h = cuda tensor (N_PIX, N_PIX, N_PIX)'''
+
+    N_PIX=5
+    _h = np.random.rand(N_PIX, N_PIX, N_PIX) * 1e-3
+    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
+    arousal, dominance, valence = (adv * N_PIX).astype(np.int64)  # find voxel
+    _h[arousal, dominance, valence] = .22
+
+
+
+
+    filled = np.ones((N_PIX, N_PIX, N_PIX), dtype=bool)
+
+    # upscale the above voxel image, leaving gaps
+    filled_2 = explode(filled)
+
+    # Shrink the gaps
+    x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
+    x[1::2, :, :] += 1
+    y[:, 1::2, :] += 1
+    z[:, :, 1::2] += 1
+
+    ax = plt.figure().add_subplot(projection='3d')
+
+    f_2 = np.ones([2 * N_PIX - 1,
+                   2 * N_PIX - 1,
+                   2 * N_PIX - 1, 4], dtype=np.float64)
+    f_2[:, :, :, 3] = explode(_h)
+    cm = plt.get_cmap('cool')
+    f_2[:, :, :, :3] = cm(f_2[:, :, :, 3])[..., :3]
+
+    f_2[:, :, :, 3] = f_2[:, :, :, 3].clip(.01, .74)
+
+    print(f_2.shape, 'f_2 AAAA')
+    ecolors_2 = f_2
+
+    ax.voxels(x, y, z, filled_2, facecolors=f_2, edgecolors=.006 * ecolors_2)
+    ax.set_aspect('equal')
+    ax.set_zticks([0, N_PIX])
+    ax.set_xticks([0, N_PIX])
+    ax.set_yticks([0, N_PIX])
+
+    ax.set_zticklabels([f'{n/N_PIX:.2f}'[0:] for n in ax.get_zticks()])
+    ax.set_zlabel('valence', fontsize=10, labelpad=0)
+    ax.set_xticklabels([f'{n/N_PIX:.2f}' for n in ax.get_xticks()])
+    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
+    # The y-axis rotation is corrected here from 275 to 90 degrees
+    ax.set_yticklabels([f'{1-n/N_PIX:.2f}' for n in ax.get_yticks()], rotation=90)
+    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
+    ax.grid(False)
+
+
+
+
+    ax.plot([N_PIX, N_PIX], [0, N_PIX + .2], [N_PIX, N_PIX], 'g', linewidth=1)
+    ax.plot([0, N_PIX], [N_PIX, N_PIX + .24], [N_PIX, N_PIX], 'k', linewidth=1)
+
+    # Bottom face lines
+    # ax.plot([0, N_PIX + line_extension], [0, 0], [0, 0], 'y', linewidth=1)
+    # ax.plot([0, 0], [0, N_PIX + line_extension], [0, 0], 'r', linewidth=1)
+    # ax.plot([N_PIX, N_PIX + line_extension], [0, N_PIX], [0, 0], 'm', linewidth=1)
+    # ax.plot([0, N_PIX], [N_PIX, N_PIX + line_extension], [0, 0], 'c', linewidth=1)
+
+    # Vertical lines
+    # ax.plot([0, 0], [0, 0], [0, N_PIX + line_extension], 'b', linewidth=1)
+    # ax.plot([N_PIX, N_PIX], [0, 0], [0, N_PIX + line_extension], 'w', linewidth=1)
+    # ax.plot([N_PIX, N_PIX], [N_PIX, N_PIX], [0, N_PIX + line_extension], 'orange', linewidth=1)
+    # ax.plot([0, 0], [N_PIX, N_PIX], [0, N_PIX + line_extension], 'lime', linewidth=1)
+
+    # # Missing lines on the top face
+    ax.plot([0, 0], [0, N_PIX], [N_PIX, N_PIX], 'darkred', linewidth=1)
+    ax.plot([0, N_PIX], [0, 0], [N_PIX, N_PIX], 'darkblue', linewidth=1)
+
+    # Set pane colors after plotting the lines
+    ax.w_xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
+    ax.w_yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
+    ax.w_zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))
+
+    # Restore the limits to prevent the plot from expanding
+    ax.set_xlim(0, N_PIX)
+    ax.set_ylim(0, N_PIX)
+    ax.set_zlim(0, N_PIX)
+    # ------


 description = (
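For reference, the explode helper added in the last hunk follows the pattern of the matplotlib voxel examples: an (N, N, N) array is copied onto the even indices of a (2N-1, 2N-1, 2N-1) array, which is why filled_2 and the facecolor array f_2 are built with 2 * N_PIX - 1 cells per axis. A quick, self-contained check of that shape arithmetic for N_PIX = 5 (only numpy assumed):

import numpy as np

def explode(data):
    # same helper as in the diff: data lands on the even indices of a (2N-1)**3 array
    size = np.array(data.shape) * 2
    data_e = np.zeros(size - 1, dtype=data.dtype)
    data_e[::2, ::2, ::2] = data
    return data_e

N_PIX = 5
h = np.zeros((N_PIX, N_PIX, N_PIX))
h[2, 3, 4] = .22                 # one highlighted voxel, as plot_expression does

e = explode(h)
print(e.shape)                   # (9, 9, 9): 2 * N_PIX - 1 per axis
print(e[4, 6, 8])                # 0.22: the original voxel, at doubled indices
print(e[1::2].sum())             # 0.0: odd slices stay empty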
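Finally, a hedged usage sketch of the new plot_expression, assuming it runs inside app.py where numpy, matplotlib.pyplot and the function itself are already in scope: the values below are arbitrary, the function expects arousal, dominance and valence roughly in [0, 1] (it clips to [0, 0.99] before picking the highlighted voxel), and it draws onto a fresh figure rather than returning one, so the caller has to fetch the current figure itself:

import matplotlib.pyplot as plt

# example values in [0, 1], e.g. as produced upstream by process_func()
plot_expression(arousal=0.71, dominance=0.34, valence=0.62)

fig = plt.gcf()                                 # plot_expression drew onto a new current figure
fig.savefig('expression_voxels.png', dpi=150)   # or pass the figure to the app's plot output
plt.close(fig)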