Update app.py
app.py
CHANGED
@@ -1,4 +1,3 @@
import os
import streamlit as st
import tempfile
@@ -32,24 +31,46 @@ st.set_page_config(layout="wide", page_title="Voice Based Sentiment Analysis")
st.title("π Voice Based Sentiment Analysis")
st.write("Detect emotions, sentiment, and sarcasm from your voice with state-of-the-art accuracy using OpenAI Whisper.")


# Emotion Detection Function
@st.cache_resource
def get_emotion_classifier():
    try:
        # `device` is expected to be defined in the unchanged lines of this file (not shown in this diff).
        tokenizer = AutoTokenizer.from_pretrained("SamLowe/roberta-base-go_emotions", use_fast=True)
        model = AutoModelForSequenceClassification.from_pretrained("SamLowe/roberta-base-go_emotions")
        model = model.to(device)

        # Changed from device=-1 if device.type == "cpu" else 0
        # to ensure proper device selection
        classifier = pipeline("text-classification",
                              model=model,
                              tokenizer=tokenizer,
                              top_k=None,
                              device=0 if torch.cuda.is_available() else -1)

        # Add a verification test to make sure the model is working
        test_result = classifier("I am happy today")
        print(f"Emotion classifier test: {test_result}")

        return classifier
    except Exception as e:
        print(f"Error loading emotion model: {str(e)}")
        st.error("Failed to load emotion model. Please check logs.")
        # Return a basic fallback that won't crash
        return None


def perform_emotion_detection(text):
    try:
        if not text or len(text.strip()) < 3:
            return {}, "neutral", {}, "NEUTRAL"

        emotion_classifier = get_emotion_classifier()
        if emotion_classifier is None:
            # Mirror the sarcasm path: bail out cleanly if the cached model failed to load.
            st.error("Emotion classifier not available.")
            return {}, "neutral", {}, "NEUTRAL"
        emotion_results = emotion_classifier(text)[0]

        emotion_map = {
            "admiration": "π€©", "amusement": "π", "anger": "π‘", "annoyance": "π",
            "approval": "π", "caring": "π€", "confusion": "π", "curiosity": "π§",
            "desire": "π", "disappointment": "π", "disapproval": "π", "disgust": "π€’",
            "embarrassment": "π³", "excitement": "π€©", "fear": "π¨", "gratitude": "π",
@@ -57,29 +78,29 @@ def perform_emotion_detection(text):
            "optimism": "π", "pride": "π", "realization": "π‘", "relief": "π",
            "remorse": "π", "sadness": "π", "surprise": "π²", "neutral": "π"
        }

        positive_emotions = ["admiration", "amusement", "approval", "caring", "desire",
                             "excitement", "gratitude", "joy", "love", "optimism", "pride", "relief"]
        negative_emotions = ["anger", "annoyance", "disappointment", "disapproval", "disgust",
                             "embarrassment", "fear", "grief", "nervousness", "remorse", "sadness"]
        neutral_emotions = ["confusion", "curiosity", "realization", "surprise", "neutral"]

        # Fix 1: Create a clean emotions dictionary from results
        emotions_dict = {}
        for result in emotion_results:
            emotions_dict[result['label']] = result['score']

        # Fix 2: Filter out very low scores (below threshold)
        filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.02}

        # If filtered dictionary is empty, fall back to original
        if not filtered_emotions:
            filtered_emotions = emotions_dict

        # Fix 3: Make sure we properly find the top emotion
        top_emotion = max(filtered_emotions, key=filtered_emotions.get)
        top_score = filtered_emotions[top_emotion]

        # Fix 4: More robust sentiment assignment
        if top_emotion in positive_emotions:
            sentiment = "POSITIVE"
@@ -88,12 +109,12 @@ def perform_emotion_detection(text):
        else:
            # If the top emotion is neutral but there are strong competing emotions, use them
            competing_emotions = sorted(filtered_emotions.items(), key=lambda x: x[1], reverse=True)[:3]

            # Check if there's a close second non-neutral emotion
            if len(competing_emotions) > 1:
                if (competing_emotions[0][0] in neutral_emotions and
                        competing_emotions[1][0] not in neutral_emotions and
                        competing_emotions[1][1] > 0.5 * competing_emotions[0][1]):
                    # Use the second strongest emotion instead
                    top_emotion = competing_emotions[1][0]
                    if top_emotion in positive_emotions:
@@ -106,33 +127,55 @@ def perform_emotion_detection(text):
                        sentiment = "NEUTRAL"
                else:
                    sentiment = "NEUTRAL"

        # Log for debugging
        print(f"Text: {text[:50]}...")
        print(f"Top 3 emotions: {sorted(filtered_emotions.items(), key=lambda x: x[1], reverse=True)[:3]}")
        print(f"Selected top emotion: {top_emotion} ({filtered_emotions.get(top_emotion, 0):.3f})")
        print(f"Sentiment determined: {sentiment}")

        print(f"All emotions detected: {emotions_dict}")
        print(f"Filtered emotions: {filtered_emotions}")
        print("Emotion classification threshold: 0.02")

        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        print(f"Exception in emotion detection: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"


# Sarcasm Detection Function
@st.cache_resource
def get_sarcasm_classifier():
    try:
        tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony", use_fast=True)
        model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
        model = model.to(device)
        classifier = pipeline("text-classification", model=model, tokenizer=tokenizer,
                              device=0 if torch.cuda.is_available() else -1)

        # Add a verification test to ensure the model is working
        test_result = classifier("This is totally amazing")
        print(f"Sarcasm classifier test: {test_result}")

        return classifier
    except Exception as e:
        print(f"Error loading sarcasm model: {str(e)}")
        st.error("Failed to load sarcasm model. Please check logs.")
        return None


def perform_sarcasm_detection(text):
    try:
        if not text or len(text.strip()) < 3:
            return False, 0.0

        sarcasm_classifier = get_sarcasm_classifier()
        if sarcasm_classifier is None:
            st.error("Sarcasm classifier not available.")
            return False, 0.0

        result = sarcasm_classifier(text)[0]
        # For cardiffnlp/twitter-roberta-base-irony, LABEL_1 corresponds to the irony class.
        is_sarcastic = result['label'] == "LABEL_1"
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
@@ -141,11 +184,12 @@ def perform_sarcasm_detection(text):
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0


# Validate audio quality
def validate_audio(audio_path):
    try:
        sound = AudioSegment.from_file(audio_path)
        if sound.dBFS < -55:
            st.warning("Audio volume is too low. Please record or upload a louder audio.")
            return False
        if len(sound) < 1000:  # Less than 1 second
@@ -156,36 +200,43 @@ def validate_audio(audio_path):
        st.error("Invalid or corrupted audio file.")
        return False


# Speech Recognition with Whisper
@st.cache_resource
def load_whisper_model():
    try:
        model = whisper.load_model("large-v3")
        return model
    except Exception as e:
        print(f"Error loading Whisper model: {str(e)}")
        st.error("Failed to load Whisper model. Please check logs.")
        return None

def transcribe_audio(audio_path, show_alternative=False):
    try:
        st.write(f"Processing audio file: {audio_path}")
        sound = AudioSegment.from_file(audio_path)
        st.write(
            f"Audio duration: {len(sound) / 1000:.2f}s, Sample rate: {sound.frame_rate}, Channels: {sound.channels}")

        # Convert to mono WAV for Whisper (Whisper resamples to 16 kHz itself when loading the file)
        temp_wav_path = os.path.join(tempfile.gettempdir(), "temp_converted.wav")
        sound = sound.set_frame_rate(22050)
        sound = sound.set_channels(1)
        sound.export(temp_wav_path, format="wav")

        # Load Whisper model
        model = load_whisper_model()

        # Transcribe audio
        result = model.transcribe(temp_wav_path, language="en")
        main_text = result["text"].strip()

        # Clean up
        if os.path.exists(temp_wav_path):
            os.remove(temp_wav_path)

        # Whisper doesn't provide alternatives, so return empty list
        if show_alternative:
            return main_text, []
@@ -194,32 +245,39 @@ def transcribe_audio(audio_path, show_alternative=False):
        st.error(f"Transcription failed: {str(e)}")
        return ("", []) if show_alternative else ""


# Function to handle uploaded audio files
def process_uploaded_audio(audio_file):
    if not audio_file:
        return None

    try:
        temp_dir = tempfile.gettempdir()

        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ['wav', 'mp3', 'ogg']:
            st.error("Unsupported audio format. Please upload WAV, MP3, or OGG.")
            return None
        temp_file_path = os.path.join(temp_dir, f"uploaded_audio_{int(time.time())}.{ext}")

        with open(temp_file_path, "wb") as f:
            f.write(audio_file.getvalue())

        if not validate_audio(temp_file_path):
            return None

        return temp_file_path
    except Exception as e:
        st.error(f"Error processing uploaded audio: {str(e)}")
        return None


# Show model information
def show_model_info():
    st.sidebar.header("π§ About the Models")

    model_tabs = st.sidebar.tabs(["Emotion", "Sarcasm", "Speech"])

    with model_tabs[0]:
        st.markdown("""
        *Emotion Model*: SamLowe/roberta-base-go_emotions
@@ -228,7 +286,7 @@ def show_model_info():
        - Micro-F1: 0.46
        [π Model Hub](https://huggingface.co/SamLowe/roberta-base-go_emotions)
        """)

    with model_tabs[1]:
        st.markdown("""
        *Sarcasm Model*: cardiffnlp/twitter-roberta-base-irony
@@ -237,7 +295,7 @@ def show_model_info():
        - F1-score: 0.705
        [π Model Hub](https://huggingface.co/cardiffnlp/twitter-roberta-base-irony)
        """)

    with model_tabs[2]:
        st.markdown("""
        *Speech Recognition*: OpenAI Whisper (large-v3)
@@ -249,8 +307,10 @@ def show_model_info():
        [π Model Details](https://github.com/openai/whisper)
        """)


# Custom audio recorder using HTML/JS
def custom_audio_recorder():
    st.warning("Browser-based recording requires microphone access and a modern browser. If recording fails, try uploading an audio file instead.")
    audio_recorder_html = """
    <script>
    var audioRecorder = {
@@ -267,11 +327,11 @@ def custom_audio_recorder():
                audioRecorder.streamBeingCaptured = stream;
                audioRecorder.mediaRecorder = new MediaRecorder(stream);
                audioRecorder.audioBlobs = [];

                audioRecorder.mediaRecorder.addEventListener("dataavailable", event => {
                    audioRecorder.audioBlobs.push(event.data);
                });

                audioRecorder.mediaRecorder.start();
            });
        }
@@ -279,14 +339,14 @@ def custom_audio_recorder():
        stop: function() {
            return new Promise(resolve => {
                let mimeType = audioRecorder.mediaRecorder.mimeType;

                audioRecorder.mediaRecorder.addEventListener("stop", () => {
                    let audioBlob = new Blob(audioRecorder.audioBlobs, { type: mimeType });
                    resolve(audioBlob);
                });

                audioRecorder.mediaRecorder.stop();

                audioRecorder.stopStream();
                audioRecorder.resetRecordingProperties();
            });
@@ -304,7 +364,7 @@ def custom_audio_recorder():
    var recordButton = document.getElementById('record-button');
    var audioElement = document.getElementById('audio-playback');
    var audioData = document.getElementById('audio-data');

    function toggleRecording() {
        if (!isRecording) {
            audioRecorder.start()
@@ -321,7 +381,7 @@ def custom_audio_recorder():
                .then(audioBlob => {
                    const audioUrl = URL.createObjectURL(audioBlob);
                    audioElement.src = audioUrl;

                    const reader = new FileReader();
                    reader.readAsDataURL(audioBlob);
                    reader.onloadend = function() {
@@ -330,7 +390,7 @@ def custom_audio_recorder():
                        const streamlitMessage = {type: "streamlit:setComponentValue", value: base64data};
                        window.parent.postMessage(streamlitMessage, "*");
                    }

                    isRecording = false;
                    recordButton.textContent = 'Start Recording';
                    recordButton.classList.remove('recording');
@@ -341,7 +401,7 @@ def custom_audio_recorder():
        recordButton = document.getElementById('record-button');
        audioElement = document.getElementById('audio-playback');
        audioData = document.getElementById('audio-data');

        recordButton.addEventListener('click', toggleRecording);
    });
    </script>
@@ -377,18 +437,20 @@ def custom_audio_recorder():
    }
    </style>
    """

    return components.html(audio_recorder_html, height=150)


# Function to display analysis results
def display_analysis_results(transcribed_text):
    # Fix 5: Add debugging to track what's happening
    st.session_state.debug_info = st.session_state.get('debug_info', [])
    st.session_state.debug_info.append(f"Processing text: {transcribed_text[:50]}...")
    st.session_state.debug_info = st.session_state.debug_info[-100:]  # Keep last 100 entries

    emotions_dict, top_emotion, emotion_map, sentiment = perform_emotion_detection(transcribed_text)
    is_sarcastic, sarcasm_score = perform_sarcasm_detection(transcribed_text)

    # Add results to debug info
    st.session_state.debug_info.append(f"Top emotion: {top_emotion}, Sentiment: {sentiment}")
    st.session_state.debug_info.append(f"Sarcasm: {is_sarcastic}, Score: {sarcasm_score:.3f}")
@@ -397,7 +459,7 @@ def display_analysis_results(transcribed_text):
    st.text_area("Text", transcribed_text, height=150, disabled=True, help="The audio converted to text.")

    confidence_score = min(0.95, max(0.70, len(transcribed_text.split()) / 50))
    st.caption(f"Estimated transcription confidence: {confidence_score:.2f} (based on text length)")

    st.header("Analysis Results")
    col1, col2 = st.columns([1, 2])
@@ -417,13 +479,14 @@ def display_analysis_results(transcribed_text):
    with col2:
        st.subheader("Emotions")
        if emotions_dict:
            st.markdown(
                f"*Dominant:* {emotion_map.get(top_emotion, 'β')} {top_emotion.capitalize()} (Score: {emotions_dict[top_emotion]:.3f})")
            sorted_emotions = sorted(emotions_dict.items(), key=lambda x: x[1], reverse=True)
            top_emotions = sorted_emotions[:8]
            emotions = [e[0] for e in top_emotions]
            scores = [e[1] for e in top_emotions]
            fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'},
                         title="Top Emotions Distribution", color=emotions,
                         color_discrete_sequence=px.colors.qualitative.Bold)
            fig.update_layout(yaxis_range=[0, 1], showlegend=False, title_font_size=14)
            st.plotly_chart(fig, use_container_width=True)
@@ -434,7 +497,7 @@ def display_analysis_results(transcribed_text):
    with st.expander("Debug Information", expanded=False):
        st.write("Debugging information for troubleshooting:")
        for i, debug_line in enumerate(st.session_state.debug_info[-10:]):
            st.text(f"{i + 1}. {debug_line}")
        if emotions_dict:
            st.write("Raw emotion scores:")
            for emotion, score in sorted(emotions_dict.items(), key=lambda x: x[1], reverse=True):
@@ -455,101 +518,104 @@ def display_analysis_results(transcribed_text):
        - Speech patterns
        """)


# Process base64 audio data
def process_base64_audio(base64_data):
    try:
        base64_binary = base64_data.split(',')[1]
        binary_data = base64.b64decode(base64_binary)

        temp_dir = tempfile.gettempdir()
        temp_file_path = os.path.join(temp_dir, f"recording_{int(time.time())}.wav")

        with open(temp_file_path, "wb") as f:
            f.write(binary_data)

        if not validate_audio(temp_file_path):
            return None

        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio data: {str(e)}")
        return None


# Main App Logic
def main():
    # Fix 7: Initialize session state for debugging
    if 'debug_info' not in st.session_state:
        st.session_state.debug_info = []

    tab1, tab2 = st.tabs(["π Upload Audio", "π Record Audio"])

    with tab1:
        st.header("Upload an Audio File")
        audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"],
                                      help="Upload an audio file for analysis")

        if audio_file:
            st.audio(audio_file.getvalue())
            st.caption("π§ Uploaded Audio Playback")

            upload_button = st.button("Analyze Upload", key="analyze_upload")

            if upload_button:
                with st.spinner('Analyzing audio with advanced precision...'):
                    temp_audio_path = process_uploaded_audio(audio_file)
                    if temp_audio_path:
                        main_text, alternatives = transcribe_audio(temp_audio_path, show_alternative=True)

                        if main_text:
                            if alternatives:
                                with st.expander("Alternative transcriptions detected", expanded=False):
                                    for i, alt in enumerate(alternatives[:3], 1):
                                        st.write(f"{i}. {alt}")

                            display_analysis_results(main_text)
                        else:
                            st.error("Could not transcribe the audio. Please try again with clearer audio.")

                        if os.path.exists(temp_audio_path):
                            os.remove(temp_audio_path)

    with tab2:
        st.header("Record Your Voice")
        st.write("Use the recorder below to analyze your speech in real-time.")

        st.subheader("Browser-Based Recorder")
        st.write("Click the button below to start/stop recording.")

        audio_data = custom_audio_recorder()

        if audio_data:
            analyze_rec_button = st.button("Analyze Recording", key="analyze_rec")

            if analyze_rec_button:
                with st.spinner("Processing your recording..."):
                    temp_audio_path = process_base64_audio(audio_data)

                    if temp_audio_path:
                        transcribed_text = transcribe_audio(temp_audio_path)

                        if transcribed_text:
                            display_analysis_results(transcribed_text)
                        else:
                            st.error("Could not transcribe the audio. Please try speaking more clearly.")

                        if os.path.exists(temp_audio_path):
                            os.remove(temp_audio_path)

        st.subheader("Manual Text Input")
        st.write("If recording doesn't work, you can type your text here:")

        manual_text = st.text_area("Enter text to analyze:", placeholder="Type what you want to analyze...")
        analyze_text_button = st.button("Analyze Text", key="analyze_manual")

        if analyze_text_button and manual_text:
            display_analysis_results(manual_text)

    show_model_info()


if __name__ == "__main__":
    main()
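As a quick sanity check of the helpers this commit touches, the snippet below is a minimal sketch (not part of the commit) that calls the updated text-analysis functions directly. It assumes the file above is saved as app.py and importable, that the SamLowe/roberta-base-go_emotions and cardiffnlp/twitter-roberta-base-irony models can be downloaded, and that streamlit, transformers, torch, openai-whisper, pydub, and plotly are installed (ffmpeg is needed for the audio paths). The script name, sample sentence, and printed values are illustrative only.

# sanity_check.py - illustrative only; the filename and sample text are placeholders
from app import perform_emotion_detection, perform_sarcasm_detection

sample = "Oh great, another Monday morning."

# perform_emotion_detection returns (all_scores, top_emotion, emoji_map, sentiment)
emotions, top_emotion, emoji_map, sentiment = perform_emotion_detection(sample)
# perform_sarcasm_detection returns (is_sarcastic, sarcasm_score)
is_sarcastic, sarcasm_score = perform_sarcasm_detection(sample)

print(f"Top emotion: {top_emotion} ({sentiment})")
print(f"Sarcastic: {is_sarcastic} (score {sarcasm_score:.3f})")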