Dionyssos committed on
Commit fb65e18 · 1 Parent(s): ad493ec
Files changed (2)
  1. README.md +1 -1
  2. app.py +108 -39
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Speech analysis
-emoji:
+emoji: 🌀
 colorFrom: gray
 colorTo: gray
 sdk: gradio
app.py CHANGED
@@ -1,5 +1,5 @@
 import typing
-
+import types  # fusion of forward() of Wav2Vec2
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
@@ -58,16 +58,95 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
 
     def forward(
             self,
-            input_values,
+            frozen_cnn7,
     ):
 
-        outputs = self.wav2vec2(input_values)
-        hidden_states = outputs[0]
+        hidden_states = self.wav2vec2(frozen_cnn7)  # skips the frozen CNN; runs feature_projection and the Transformer
+
         hidden_states = torch.mean(hidden_states, dim=1)
         logits_age = self.age(hidden_states)
         logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
 
         return hidden_states, logits_age, logits_gender
+
+
+
+# == Fusion: redefine the age model's Wav2Vec2Model.forward to accept CNN7 features already computed by the expression model
+def _forward(
+        self,
+        extract_features,
+        attention_mask=None, mask_time_indices=None):
+    # extract_features: CNN7 features of wav2vec2, as computed by its CNN7 feature extractor
+
+
+    if attention_mask is not None:
+        # compute reduced attention_mask corresponding to feature vectors
+        attention_mask = self._get_feature_vector_attention_mask(
+            extract_features.shape[1], attention_mask, add_adapter=False
+        )
+
+    hidden_states, extract_features = self.feature_projection(extract_features)
+    hidden_states = self._mask_hidden_states(
+        hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+    )
+
+    encoder_outputs = self.encoder(
+        hidden_states,
+        attention_mask=attention_mask,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    )
+
+    hidden_states = encoder_outputs[0]
+
+    if self.adapter is not None:
+        raise ValueError
+        hidden_states = self.adapter(hidden_states)
+
+    return hidden_states
+# ===============================================
+
+
+# ================== Forward & CNN features
+def _forward_and_cnn7(
+        self,
+        input_values,
+        attention_mask=None, mask_time_indices=None
+):
+
+
+    frozen_cnn7 = self.feature_extractor(input_values)
+    frozen_cnn7 = frozen_cnn7.transpose(1, 2)
+
+    if attention_mask is not None:
+        # compute reduced attention_mask corresponding to feature vectors
+        attention_mask = self._get_feature_vector_attention_mask(
+            frozen_cnn7.shape[1], attention_mask, add_adapter=False
+        )
+
+    hidden_states, extract_features = self.feature_projection(frozen_cnn7)  # feature_projection has grad enabled (not frozen)
+    hidden_states = self._mask_hidden_states(
+        hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+    )
+
+    encoder_outputs = self.encoder(
+        hidden_states,
+        attention_mask=attention_mask,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    )
+
+    hidden_states = encoder_outputs[0]
+
+    if self.adapter is not None:
+        raise ValueError
+        hidden_states = self.adapter(hidden_states)
+
+    return hidden_states, frozen_cnn7  # feature_projection is trainable, so we cannot reuse the projected hidden states from the official wav2vec2.forward
+
+# =============================
 
 
 class ExpressionHead(nn.Module):
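For orientation: the tensor handed between the two patched forwards defined above (_forward_and_cnn7 produces it, _forward consumes it) is the output of wav2vec2's seven-layer convolutional feature encoder ("CNN7"), transposed to (batch, frames, channels) before it enters feature_projection. A minimal sketch of that shape contract, assuming torch and transformers are installed and using the generic facebook/wav2vec2-base checkpoint as a stand-in for the checkpoints loaded in app.py (512 is the conv width of the standard wav2vec2 configs; the frame count depends on input length):

import torch
from transformers import Wav2Vec2Model

w2v = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")  # stand-in, not the app.py checkpoints
w2v.eval()

x = torch.zeros(1, 16000)                      # one second of 16 kHz audio, batch size 1
with torch.no_grad():
    cnn7 = w2v.feature_extractor(x)            # (1, 512, 49): CNN7 features, channels first
    cnn7 = cnn7.transpose(1, 2)                # (1, 49, 512): what _forward_and_cnn7 returns as frozen_cnn7
    hidden, _ = w2v.feature_projection(cnn7)   # (1, 49, hidden_size): what _forward projects before the encoder
print(cnn7.shape, hidden.shape)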
@@ -106,12 +185,11 @@ class ExpressionModel(Wav2Vec2PreTrainedModel):
         self.init_weights()
 
     def forward(self, input_values):
-        outputs = self.wav2vec2(input_values)
-        hidden_states = outputs[0]
+        hidden_states, frozen_cnn7 = self.wav2vec2(input_values)
         hidden_states = torch.mean(hidden_states, dim=1)
         logits = self.classifier(hidden_states)
 
-        return hidden_states, logits
+        return hidden_states, logits, frozen_cnn7
 
 
 # Load models from hub
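The next hunk wires the fusion up by rebinding forward on each model's wav2vec2 submodule with types.MethodType. A self-contained sketch of that rebinding pattern (toy classes, not the app.py models), showing why the function must be bound to the submodule instance whose attributes it uses:

import types

class Backbone:
    def __init__(self):
        self.scale = 2.0

    def forward(self, x):
        return self.scale * x

class Wrapper:
    def __init__(self):
        self.backbone = Backbone()

def _patched_forward(self, x):
    # 'self' must be the Backbone instance so that self.scale resolves
    return self.scale * x + 1.0

w = Wrapper()
# bind to the submodule that owns the attributes the patched function touches
w.backbone.forward = types.MethodType(_patched_forward, w.backbone)
print(w.backbone.forward(3.0))  # 7.0

For torch nn.Module instances the same idea applies: module(x) dispatches through the instance's forward attribute (outside of tracing), so rebinding model.wav2vec2.forward changes what self.wav2vec2(...) executes.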
@@ -120,46 +198,37 @@ age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name)
 expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
 expression_model = ExpressionModel.from_pretrained(expression_model_name)
 
+# The expression (emotion) model computes the CNN7 features once; rebind both forwards
+
+age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model.wav2vec2)
+expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model.wav2vec2)
 
 def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
-    r"""Predict age and gender or extract embeddings from raw audio signal."""
-    # run through processor to normalize signal
-    # always returns a batch, so we just get the first entry
-    # then we put it on the device
-    results = []
-    for processor, model in zip(
-        [age_gender_processor, expression_processor],
-        [age_gender_model, expression_model],
-    ):
-        y = processor(x, sampling_rate=sampling_rate)
-        y = y['input_values'][0]
-        y = y.reshape(1, -1)
-        y = torch.from_numpy(y).to(device)
-
-        # run through model
-        with torch.no_grad():
-            y = model(y)
-            if len(y) == 3:
-                # Age-gender model
-                y = torch.hstack([y[1], y[2]])
-            else:
-                # Expression model
-                y = y[1]
-
-        # convert to numpy
-        y = y.detach().cpu().numpy()
-        results.append(y[0])
+
+    # normalize and batch the audio with the expression processor
+    y = expression_processor(x, sampling_rate=sampling_rate)
+    y = y['input_values'][0]
+    y = y.reshape(1, -1)
+    y = torch.from_numpy(y).to(device)
+
+    # run the expression model, then reuse its frozen CNN7 features for age/gender
+    with torch.no_grad():
+        _, logits_expression, frozen_cnn7 = expression_model(y)
+
+        _, logits_age, logits_gender = age_gender_model(frozen_cnn7=frozen_cnn7)
 
     # Plot A/D/V values
-    plot_expression(results[1][0], results[1][1], results[1][2])
+    plot_expression(logits_expression[0, 0].item(),  # .item() implies detach().cpu()
+                    logits_expression[0, 1].item(),
+                    logits_expression[0, 2].item())
     expression_file = "expression.png"
     plt.savefig(expression_file)
     return (
-        f"{round(100 * results[0][0])} years",  # age
+        f"{round(100 * logits_age[0, 0].item())} years",  # age
         {
-            "female": results[0][1],
-            "male": results[0][2],
-            "child": results[0][3],
+            "female": logits_gender[0, 0].item(),
+            "male": logits_gender[0, 1].item(),
+            "child": logits_gender[0, 2].item(),
         },
         expression_file,
     )
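A hypothetical smoke test for the fused pipeline (not part of the commit); it assumes the snippet is appended to app.py, where device, plot_expression and the model/processor globals are already defined, and that the models expect 16 kHz mono input:

import numpy as np

if __name__ == "__main__":
    sr = 16000
    t = np.linspace(0, 1.0, sr, endpoint=False)
    x = (0.1 * np.sin(2 * np.pi * 220.0 * t)).astype(np.float32)  # 1 s test tone

    age, gender_probs, png = process_func(x, sr)
    print(age)           # e.g. "25 years": the age head output, scaled by 100 and rounded
    print(gender_probs)  # dict with "female" / "male" / "child" scores
    print(png)           # "expression.png", the saved arousal/dominance/valence plot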