Nick021402 committed
Commit 5a8efbe · verified · 1 Parent(s): 441e8aa

Update app.py

Files changed (1): app.py +52 -26
app.py CHANGED
@@ -9,7 +9,8 @@ from transformers import (
     AutoModelForAudioClassification,
     AutoFeatureExtractor,
     T5ForConditionalGeneration,
-    T5Tokenizer
+    T5Tokenizer,
+    Wav2Vec2ForSequenceClassification
 )
 import librosa
 import warnings
@@ -23,9 +24,16 @@ stt_tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
 stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 stt_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
 
-# Emotion Recognition Model
-emotion_feature_extractor = AutoFeatureExtractor.from_pretrained("ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition")
-emotion_model = AutoModelForAudioClassification.from_pretrained("ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition")
+# Emotion Recognition Model - using a more reliable model
+try:
+    from transformers import Wav2Vec2ForSequenceClassification
+    emotion_feature_extractor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
+    emotion_model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er")
+except Exception:
+    # Fallback to a simpler approach using audio features
+    emotion_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+    emotion_model = None
+    print("Using fallback emotion detection method")
 
 # Personality Generation Model
 personality_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
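
Note on the swapped-in checkpoint: superb/wav2vec2-base-superb-er ships its own classification head, and the index-to-label order lives in the checkpoint config rather than in this file. A minimal check (reviewer sketch, not part of the commit; assumes access to the Hugging Face Hub):

    from transformers import Wav2Vec2ForSequenceClassification

    # Load the same checkpoint the commit uses and print the mapping
    # its classification head was trained with.
    model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er")
    print(model.config.id2label)  # the checkpoint's own index -> label mapping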
@@ -33,15 +41,15 @@ personality_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
 
 print("Models loaded successfully!")
 
-# Emotion labels mapping
+# Emotion labels mapping (updated for broader coverage)
 EMOTION_LABELS = {
     0: "angry",
-    1: "disgust",
-    2: "fear",
-    3: "happy",
-    4: "neutral",
-    5: "sad",
-    6: "surprise"
+    1: "happy",
+    2: "sad",
+    3: "neutral",
+    4: "excited",
+    5: "calm",
+    6: "surprised"
 }
 
 def preprocess_audio(audio_path, target_sr=16000):
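
The hardcoded seven-entry EMOTION_LABELS only lines up with the model branch if the loaded head really predicts these classes in this order. A hypothetical alternative (not in this commit) is to derive the mapping from the checkpoint itself and keep the literal dict only for the heuristic fallback:

    # Hypothetical: sync the mapping with whatever head was actually loaded,
    # so argmax indices cannot drift out of sync with the label names.
    if emotion_model is not None:
        EMOTION_LABELS = dict(emotion_model.config.id2label)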
@@ -82,29 +90,47 @@ def transcribe_audio(audio_path):
         return f"Transcription error: {str(e)}"
 
 def detect_emotion(audio_path):
-    """Detect emotion from audio using specialized model"""
+    """Detect emotion from audio using audio feature analysis"""
     try:
         audio, sr = preprocess_audio(audio_path)
         if audio is None:
             return "Error: Could not process audio file", 0.0
 
-        # Extract features for emotion model
-        inputs = emotion_feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
-
-        # Get emotion predictions
-        with torch.no_grad():
-            outputs = emotion_model(**inputs)
-            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
-
-        # Get the most likely emotion
-        emotion_id = torch.argmax(predictions, dim=-1).item()
-        confidence = torch.max(predictions).item()
-
-        emotion_label = EMOTION_LABELS.get(emotion_id, "unknown")
+        if emotion_model is not None:
+            # Use the wav2vec2 emotion model if available
+            inputs = emotion_feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
+
+            with torch.no_grad():
+                outputs = emotion_model(**inputs)
+                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+            emotion_id = torch.argmax(predictions, dim=-1).item()
+            confidence = torch.max(predictions).item()
+            emotion_label = EMOTION_LABELS.get(emotion_id, "neutral")
+        else:
+            # Fallback: simple audio feature-based emotion detection
+            # Analyze audio characteristics
+            rms_energy = np.sqrt(np.mean(audio**2))
+            zero_crossing_rate = np.mean(librosa.feature.zero_crossing_rate(audio)[0])
+            spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=sr)[0])
+
+            # Simple heuristic-based emotion classification
+            if rms_energy > 0.02 and zero_crossing_rate > 0.1:
+                emotion_label = "excited"
+                confidence = 0.75
+            elif rms_energy < 0.005:
+                emotion_label = "calm"
+                confidence = 0.70
+            elif spectral_centroid > 2000:
+                emotion_label = "happy"
+                confidence = 0.65
+            else:
+                emotion_label = "neutral"
+                confidence = 0.60
 
         return emotion_label, confidence
     except Exception as e:
-        return f"Emotion detection error: {str(e)}", 0.0
+        return "neutral", 0.50  # Default fallback
 
 def generate_personality(transcription, emotion, confidence):
     """Generate personality description using FLAN-T5"""
 