Dionyssos commited on
Commit
2bde17b
·
1 Parent(s): 6d576da
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -74,18 +74,17 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
74
  # Fusion = AgeWav2Vec2Model forward() will accept already computed CNN7 features from ExpressioNmodel forward()
75
  def _forward(
76
  self,
77
- extract_features,
78
  attention_mask=None):
79
- # extract_features : CNN7 features of wav2vec2 as they are calc. from CNN7 feature extractor
80
 
81
 
82
  if attention_mask is not None:
83
  # compute reduced attention_mask corresponding to feature vectors
84
  attention_mask = self._get_feature_vector_attention_mask(
85
- extract_features.shape[1], attention_mask, add_adapter=False
86
  )
87
 
88
- hidden_states, extract_features = self.feature_projection(extract_features)
89
  hidden_states = self._mask_hidden_states(
90
  hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
91
  )
@@ -121,7 +120,7 @@ def _forward_and_cnn7(
121
  frozen_cnn7.shape[1], attention_mask, add_adapter=False
122
  )
123
 
124
- hidden_states, extract_features = self.feature_projection(frozen_cnn7) # grad=True non frozen
125
  hidden_states = self._mask_hidden_states(
126
  hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
127
  )
 
74
  # Fusion = AgeWav2Vec2Model forward() will accept already computed CNN7 features from ExpressioNmodel forward()
75
  def _forward(
76
  self,
77
+ frozen_cnn7=None, # CNN7 features of wav2vec2 calc. from CNN7 feature extractor (once)
78
  attention_mask=None):
 
79
 
80
 
81
  if attention_mask is not None:
82
  # compute reduced attention_mask corresponding to feature vectors
83
  attention_mask = self._get_feature_vector_attention_mask(
84
+ frozen_cnn7.shape[1], attention_mask, add_adapter=False
85
  )
86
 
87
+ hidden_states, _ = self.feature_projection(frozen_cnn7)
88
  hidden_states = self._mask_hidden_states(
89
  hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
90
  )
 
120
  frozen_cnn7.shape[1], attention_mask, add_adapter=False
121
  )
122
 
123
+ hidden_states, _ = self.feature_projection(frozen_cnn7) # grad=True non frozen
124
  hidden_states = self._mask_hidden_states(
125
  hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
126
  )