Dionyssos committed on
Commit
4d8c40c
·
1 Parent(s): 929da88
Files changed (1) hide show
  1. app.py +102 -3
app.py CHANGED
@@ -68,7 +68,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
68
  logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
69
 
70
  return hidden_states, logits_age, logits_gender
71
-
72
  # AgeGenderModel.forward() is switched to accept computed frozen CNN7 features from ExpressioNmodel
73
 
74
  def _forward(
@@ -178,7 +178,7 @@ age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
178
  expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
179
 
180
  def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
181
-
182
  # batch audio
183
  y = expression_processor(x, sampling_rate=sampling_rate)
184
  y = y['input_values'][0]
@@ -227,7 +227,7 @@ def recognize(input_file: str) -> typing.Tuple[str, dict, str]:
227
  return process_func(signal, target_rate)
228
 
229
 
230
- def plot_expression(arousal, dominance, valence):
231
  r"""3D pixel plot of arousal, dominance, valence."""
232
  # Voxels per dimension
233
  voxels = 7
@@ -271,6 +271,105 @@ def plot_expression(arousal, dominance, valence):
271
  verticalalignment="top",
272
  )
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
 
276
  description = (
 
68
  logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
69
 
70
  return hidden_states, logits_age, logits_gender
71
+
72
  # AgeGenderModel.forward() is switched to accept computed frozen CNN7 features from ExpressioNmodel
73
 
74
  def _forward(
 
178
  expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
179
 
180
  def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
181
+
182
  # batch audio
183
  y = expression_processor(x, sampling_rate=sampling_rate)
184
  y = y['input_values'][0]
 
227
  return process_func(signal, target_rate)
228
 
229
 
230
+ def plot_expression_RIGID(arousal, dominance, valence):
231
  r"""3D pixel plot of arousal, dominance, valence."""
232
  # Voxels per dimension
233
  voxels = 7
 
271
  verticalalignment="top",
272
  )
273
 
274
+ COLORMAP = plt.get_cmap('coolwarm')
275
+ N_PIX = 5
276
+
277
+ matplotlib.rcParams['mathtext.fontset'] = 'stix'
278
+ matplotlib.rcParams['font.family'] = 'STIXGeneral'
279
+
280
def explode(data):
    """Spread a 3-D array onto a grid of size ``2 * n - 1`` along every axis.

    The values of ``data`` are written to the even indices of the result;
    the odd positions in between stay zero (``False`` for boolean input).
    The output keeps the dtype of ``data``.
    """
    expanded_shape = tuple(2 * extent - 1 for extent in data.shape)
    every_other = slice(None, None, 2)
    data_e = np.zeros(expanded_shape, dtype=data.dtype)
    data_e[every_other, every_other, every_other] = data
    return data_e
286
+
287
+
288
def plot_expression(arousal, dominance, valence):
    """Draw an ``N_PIX``-per-side voxel cube and highlight one cell.

    The three scores are clipped to ``[0, .99]`` and scaled by ``N_PIX`` to
    select a voxel. The dominance axis is flipped (``.994 - dominance``),
    which is why its tick labels are printed as ``1 - n / N_PIX``.

    The plot is drawn on a new pyplot figure; nothing is returned.
    NOTE(review): the caller presumably picks the figure up through pyplot's
    current-figure state - confirm.

    Args:
        arousal: arousal score, expected in [0, 1].
        dominance: dominance score, expected in [0, 1].
        valence: valence score, expected in [0, 1].
    """
    n_pix = N_PIX  # use the module-level resolution instead of redefining it

    # Faint random alpha for every voxel; the selected voxel is made stronger.
    _h = np.random.rand(n_pix, n_pix, n_pix) * 1e-3
    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
    i_arousal, i_dominance, i_valence = (adv * n_pix).astype(np.int64)  # find voxel
    _h[i_arousal, i_dominance, i_valence] = .22

    filled = np.ones((n_pix, n_pix, n_pix), dtype=bool)

    # Upscale the voxel image, leaving gaps between neighbouring voxels.
    filled_2 = explode(filled)

    # Shrink the gaps
    x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
    x[1::2, :, :] += 1
    y[:, 1::2, :] += 1
    z[:, :, 1::2] += 1

    ax = plt.figure().add_subplot(projection='3d')

    # RGBA per (exploded) voxel: alpha comes from _h, RGB from the colormap
    # evaluated at that same alpha value.
    side = 2 * n_pix - 1
    f_2 = np.ones([side, side, side, 4], dtype=np.float64)
    f_2[:, :, :, 3] = explode(_h)
    cm = plt.get_cmap('cool')
    f_2[:, :, :, :3] = cm(f_2[:, :, :, 3])[..., :3]

    f_2[:, :, :, 3] = f_2[:, :, :, 3].clip(.01, .74)

    ax.voxels(x, y, z, filled_2, facecolors=f_2, edgecolors=.006 * f_2)
    ax.set_aspect('equal')
    ax.set_zticks([0, n_pix])
    ax.set_xticks([0, n_pix])
    ax.set_yticks([0, n_pix])

    ax.set_zticklabels([f'{n/n_pix:.2f}' for n in ax.get_zticks()])
    ax.set_zlabel('valence', fontsize=10, labelpad=0)
    ax.set_xticklabels([f'{n/n_pix:.2f}' for n in ax.get_xticks()])
    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
    # The y-axis rotation is corrected here from 275 to 90 degrees
    ax.set_yticklabels([f'{1-n/n_pix:.2f}' for n in ax.get_yticks()], rotation=90)
    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
    ax.grid(False)

    # Guide lines on the top face (z = n_pix); two of them overshoot slightly in y.
    ax.plot([n_pix, n_pix], [0, n_pix + .2], [n_pix, n_pix], 'g', linewidth=1)
    ax.plot([0, n_pix], [n_pix, n_pix + .24], [n_pix, n_pix], 'k', linewidth=1)
    ax.plot([0, 0], [0, n_pix], [n_pix, n_pix], 'darkred', linewidth=1)
    ax.plot([0, n_pix], [0, 0], [n_pix, n_pix], 'darkblue', linewidth=1)

    # Set pane colors after plotting the lines.
    # `ax.xaxis` replaces the `ax.w_xaxis` alias, which was removed in Matplotlib 3.8.
    ax.xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))

    # Restore the limits to prevent the plot from expanding
    ax.set_xlim(0, n_pix)
    ax.set_ylim(0, n_pix)
    ax.set_zlim(0, n_pix)
372
+ # ------
373
 
374
 
375
  description = (