MaroofTechSorcerer committed on
Commit 42d828e · verified · 1 Parent(s): b2e2b24

Update app.py

Files changed (1)
  1. app.py +208 -493
app.py CHANGED
@@ -2,170 +2,83 @@ import os
2
  import streamlit as st
3
  import tempfile
4
  import torch
 
5
  import transformers
6
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
7
  import plotly.express as px
8
  import logging
9
  import warnings
10
  import whisper
11
- from pydub import AudioSegment
12
- import time
13
  import base64
14
  import io
15
  import streamlit.components.v1 as components
16
 
17
- # Suppress warnings for a clean console
18
- logging.getLogger("torch").setLevel(logging.CRITICAL)
19
- logging.getLogger("transformers").setLevel(logging.CRITICAL)
20
  warnings.filterwarnings("ignore")
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
 
23
- # Check if CUDA is available, otherwise use CPU
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- print(f"Using device: {device}")
26
 
27
- # Set Streamlit app layout
28
- st.set_page_config(layout="wide", page_title="Voice Based Sentiment Analysis")
29
 
30
- # Interface design
31
- st.title("🎙 Voice Based Sentiment Analysis")
32
- st.write("Detect emotions, sentiment, and sarcasm from your voice with state-of-the-art accuracy using OpenAI Whisper.")
33
-
34
- # Emotion Detection Function
35
  @st.cache_resource
36
- def get_emotion_classifier():
37
- try:
38
- tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion", use_fast=True)
39
- model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
40
- model = model.to(device)
41
-
42
- classifier = pipeline("text-classification",
43
- model=model,
44
- tokenizer=tokenizer,
45
- top_k=None,
46
- device=0 if torch.cuda.is_available() else -1)
47
-
48
- # Add a verification test to make sure the model is working
49
- test_result = classifier("I am happy today")
50
- print(f"Emotion classifier test: {test_result}")
51
-
52
- return classifier
53
- except Exception as e:
54
- print(f"Error loading emotion model: {str(e)}")
55
- st.error(f"Failed to load emotion model. Please check logs.")
56
- return None
57
-
58
- def perform_emotion_detection(text):
 
59
  try:
60
- if not text or len(text.strip()) < 3:
61
- return {}, "neutral", {}, "NEUTRAL"
62
-
63
- emotion_classifier = get_emotion_classifier()
64
- if emotion_classifier is None:
65
- st.error("Emotion classifier not available.")
66
- return {}, "neutral", {}, "NEUTRAL"
67
-
68
- emotion_results = emotion_classifier(text)
69
- print(f"Raw emotion classifier output: {emotion_results}")
70
- if not emotion_results or not isinstance(emotion_results, list) or not emotion_results[0]:
71
- st.error("Emotion classifier returned invalid or empty results.")
72
- return {}, "neutral", {}, "NEUTRAL"
73
-
74
- # Access the first inner list, which contains the emotion dictionaries
75
- emotion_results = emotion_results[0]
76
- emotion_map = {
77
- "joy": "😊", "anger": "😡", "disgust": "🤢", "fear": "😨",
78
- "sadness": "😭", "surprise": "😲"
79
- }
80
- positive_emotions = ["joy"]
81
- negative_emotions = ["anger", "disgust", "fear", "sadness"]
82
- neutral_emotions = ["surprise"]
83
-
84
- emotions_dict = {}
85
- for result in emotion_results:
86
- if isinstance(result, dict) and 'label' in result and 'score' in result:
87
- emotions_dict[result['label']] = result['score']
88
- else:
89
- print(f"Invalid result format: {result}")
90
-
91
- if not emotions_dict:
92
- st.error("No valid emotions detected.")
93
- return {}, "neutral", {}, "NEUTRAL"
94
-
95
  filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
96
-
97
- if not filtered_emotions:
98
- filtered_emotions = emotions_dict
99
-
100
  top_emotion = max(filtered_emotions, key=filtered_emotions.get)
101
- top_score = filtered_emotions[top_emotion]
102
-
103
- if top_emotion in positive_emotions:
104
- sentiment = "POSITIVE"
105
- elif top_emotion in negative_emotions:
106
- sentiment = "NEGATIVE"
107
- else:
108
- competing_emotions = sorted(filtered_emotions.items(), key=lambda x: x[1], reverse=True)[:3]
109
- if len(competing_emotions) > 1:
110
- if (competing_emotions[0][0] in neutral_emotions and
111
- competing_emotions[1][0] not in neutral_emotions and
112
- competing_emotions[1][1] > 0.7 * competing_emotions[0][1]):
113
- top_emotion = competing_emotions[1][0]
114
- if top_emotion in positive_emotions:
115
- sentiment = "POSITIVE"
116
- elif top_emotion in negative_emotions:
117
- sentiment = "NEGATIVE"
118
- else:
119
- sentiment = "NEUTRAL"
120
- else:
121
- sentiment = "NEUTRAL"
122
- else:
123
- sentiment = "NEUTRAL"
124
-
125
- print(f"Text: {text[:50]}...")
126
- print(f"Top 3 emotions: {sorted(filtered_emotions.items(), key=lambda x: x[1], reverse=True)[:3]}")
127
- print(f"Selected top emotion: {top_emotion} ({filtered_emotions.get(top_emotion, 0):.3f})")
128
- print(f"Sentiment determined: {sentiment}")
129
- print(f"All emotions detected: {emotions_dict}")
130
- print(f"Filtered emotions: {filtered_emotions}")
131
- print(f"Emotion classification threshold: 0.01")
132
-
133
  return emotions_dict, top_emotion, emotion_map, sentiment
134
  except Exception as e:
135
  st.error(f"Emotion detection failed: {str(e)}")
136
- print(f"Exception in emotion detection: {str(e)}")
137
  return {}, "neutral", {}, "NEUTRAL"
138
 
139
- # Sarcasm Detection Function
140
- @st.cache_resource
141
- def get_sarcasm_classifier():
142
- try:
143
- tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony", use_fast=True)
144
- model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
145
- model = model.to(device)
146
- classifier = pipeline("text-classification", model=model, tokenizer=tokenizer,
147
- device=0 if torch.cuda.is_available() else -1)
148
-
149
- # Add a verification test to ensure the model is working
150
- test_result = classifier("This is totally amazing")
151
- print(f"Sarcasm classifier test: {test_result}")
152
-
153
- return classifier
154
- except Exception as e:
155
- print(f"Error loading sarcasm model: {str(e)}")
156
- st.error(f"Failed to load sarcasm model. Please check logs.")
157
- return None
158
-
159
- def perform_sarcasm_detection(text):
160
  try:
161
- if not text or len(text.strip()) < 3:
162
- return False, 0.0
163
-
164
- sarcasm_classifier = get_sarcasm_classifier()
165
- if sarcasm_classifier is None:
166
- st.error("Sarcasm classifier not available.")
167
- return False, 0.0
168
-
169
  result = sarcasm_classifier(text)[0]
170
  is_sarcastic = result['label'] == "LABEL_1"
171
  sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
@@ -174,425 +87,227 @@ def perform_sarcasm_detection(text):
174
  st.error(f"Sarcasm detection failed: {str(e)}")
175
  return False, 0.0
176
 
177
- # Validate audio quality
178
  def validate_audio(audio_path):
179
  try:
180
- sound = AudioSegment.from_file(audio_path)
181
- if sound.dBFS < -55:
182
- st.warning("Audio volume is too low. Please record or upload a louder audio.")
183
  return False
184
- if len(sound) < 1000: # Less than 1 second
185
- st.warning("Audio is too short. Please record a longer audio.")
186
  return False
187
  return True
188
  except:
189
- st.error("Invalid or corrupted audio file.")
190
  return False
191
 
192
- # Speech Recognition with Whisper
193
- @st.cache_resource
194
- def load_whisper_model():
195
  try:
196
- model = whisper.load_model("large-v3")
197
- return model
198
- except Exception as e:
199
- print(f"Error loading Whisper model: {str(e)}")
200
- st.error(f"Failed to load Whisper model. Please check logs.")
201
- return None
202
-
203
- def transcribe_audio(audio_path, show_alternative=False):
204
- try:
205
- st.write(f"Processing audio file: {audio_path}")
206
- sound = AudioSegment.from_file(audio_path)
207
- st.write(
208
- f"Audio duration: {len(sound) / 1000:.2f}s, Sample rate: {sound.frame_rate}, Channels: {sound.channels}")
209
-
210
- # Convert to WAV format (16kHz, mono) for Whisper
211
- temp_wav_path = os.path.join(tempfile.gettempdir(), "temp_converted.wav")
212
- sound = sound.set_frame_rate(22050)
213
- sound = sound.set_channels(1)
214
- sound.export(temp_wav_path, format="wav")
215
-
216
- # Load Whisper model
217
- model = load_whisper_model()
218
-
219
- # Transcribe audio
220
- result = model.transcribe(temp_wav_path, language="en")
221
- main_text = result["text"].strip()
222
-
223
- # Clean up
224
- if os.path.exists(temp_wav_path):
225
- os.remove(temp_wav_path)
226
-
227
- # Whisper doesn't provide alternatives, so return empty list
228
- if show_alternative:
229
- return main_text, []
230
- return main_text
231
  except Exception as e:
232
  st.error(f"Transcription failed: {str(e)}")
233
- return "", [] if show_alternative else ""
234
 
235
- # Function to handle uploaded audio files
236
  def process_uploaded_audio(audio_file):
237
- if not audio_file:
238
- return None
239
-
240
  try:
241
- temp_dir = tempfile.gettempdir()
242
-
243
  ext = audio_file.name.split('.')[-1].lower()
244
  if ext not in ['wav', 'mp3', 'ogg']:
245
- st.error("Unsupported audio format. Please upload WAV, MP3, or OGG.")
246
  return None
247
- temp_file_path = os.path.join(temp_dir, f"uploaded_audio_{int(time.time())}.{ext}")
248
-
249
- with open(temp_file_path, "wb") as f:
250
- f.write(audio_file.getvalue())
251
-
252
  if not validate_audio(temp_file_path):
 
253
  return None
254
-
255
  return temp_file_path
256
  except Exception as e:
257
- st.error(f"Error processing uploaded audio: {str(e)}")
258
  return None
259
 
260
- # Show model information
261
- def show_model_info():
262
- st.sidebar.header("🧠 About the Models")
263
-
264
- model_tabs = st.sidebar.tabs(["Emotion", "Sarcasm", "Speech"])
265
-
266
- with model_tabs[0]:
267
- st.markdown("""
268
- *Emotion Model*: distilbert-base-uncased-emotion
269
- - Fine-tuned for six emotions (joy, anger, disgust, fear, sadness, surprise)
270
- - Architecture: DistilBERT base
271
- - High accuracy for basic emotion classification
272
- [🔍 Model Hub](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion)
273
- """)
274
-
275
- with model_tabs[1]:
276
- st.markdown("""
277
- *Sarcasm Model*: cardiffnlp/twitter-roberta-base-irony
278
- - Trained on SemEval-2018 Task 3 (Twitter irony dataset)
279
- - Architecture: RoBERTa base
280
- - F1-score: 0.705
281
- [🔍 Model Hub](https://huggingface.co/cardiffnlp/twitter-roberta-base-irony)
282
- """)
283
-
284
- with model_tabs[2]:
285
- st.markdown("""
286
- *Speech Recognition*: OpenAI Whisper (large-v3)
287
- - State-of-the-art model for speech-to-text
288
- - Accuracy: ~5-10% WER on clean English audio
289
- - Robust to noise, accents, and varied conditions
290
- - Runs locally, no internet required
291
- *Tips*: Use good mic, reduce noise, speak clearly
292
- [🔍 Model Details](https://github.com/openai/whisper)
293
- """)
294
 
295
- # Custom audio recorder using HTML/JS
296
  def custom_audio_recorder():
297
- st.warning("Browser-based recording requires microphone access and a modern browser. If recording fails, try uploading an audio file instead.")
298
  audio_recorder_html = """
299
  <script>
300
- var audioRecorder = {
301
- audioBlobs: [],
302
- mediaRecorder: null,
303
- streamBeingCaptured: null,
304
- start: function() {
305
- if (!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia)) {
306
- return Promise.reject(new Error('mediaDevices API or getUserMedia method is not supported in this browser.'));
307
- }
308
- else {
309
- return navigator.mediaDevices.getUserMedia({ audio: true })
310
- .then(stream => {
311
- audioRecorder.streamBeingCaptured = stream;
312
- audioRecorder.mediaRecorder = new MediaRecorder(stream);
313
- audioRecorder.audioBlobs = [];
314
-
315
- audioRecorder.mediaRecorder.addEventListener("dataavailable", event => {
316
- audioRecorder.audioBlobs.push(event.data);
317
- });
318
-
319
- audioRecorder.mediaRecorder.start();
320
- });
321
- }
322
- },
323
- stop: function() {
324
- return new Promise(resolve => {
325
- let mimeType = audioRecorder.mediaRecorder.mimeType;
326
-
327
- audioRecorder.mediaRecorder.addEventListener("stop", () => {
328
- let audioBlob = new Blob(audioRecorder.audioBlobs, { type: mimeType });
329
- resolve(audioBlob);
330
- });
331
-
332
- audioRecorder.mediaRecorder.stop();
333
-
334
- audioRecorder.stopStream();
335
- audioRecorder.resetRecordingProperties();
336
- });
337
- },
338
- stopStream: function() {
339
- audioRecorder.streamBeingCaptured.getTracks()
340
- .forEach(track => track.stop());
341
- },
342
- resetRecordingProperties: function() {
343
- audioRecorder.mediaRecorder = null;
344
- audioRecorder.streamBeingCaptured = null;
345
  }
346
  }
347
- var isRecording = false;
348
- var recordButton = document.getElementById('record-button');
349
- var audioElement = document.getElementById('audio-playback');
350
- var audioData = document.getElementById('audio-data');
351
-
352
- function toggleRecording() {
353
- if (!isRecording) {
354
- audioRecorder.start()
355
- .then(() => {
356
- isRecording = true;
357
- recordButton.textContent = 'Stop Recording';
358
- recordButton.classList.add('recording');
359
- })
360
- .catch(error => {
361
- alert('Error starting recording: ' + error.message);
362
- });
363
- } else {
364
- audioRecorder.stop()
365
- .then(audioBlob => {
366
- const audioUrl = URL.createObjectURL(audioBlob);
367
- audioElement.src = audioUrl;
368
-
369
- const reader = new FileReader();
370
- reader.readAsDataURL(audioBlob);
371
- reader.onloadend = function() {
372
- const base64data = reader.result;
373
- audioData.value = base64data;
374
- const streamlitMessage = {type: "streamlit:setComponentValue", value: base64data};
375
- window.parent.postMessage(streamlitMessage, "*");
376
- }
377
 
378
- isRecording = false;
379
- recordButton.textContent = 'Start Recording';
380
- recordButton.classList.remove('recording');
381
- });
382
- }
383
  }
384
- document.addEventListener('DOMContentLoaded', function() {
385
- recordButton = document.getElementById('record-button');
386
- audioElement = document.getElementById('audio-playback');
387
- audioData = document.getElementById('audio-data');
388
 
389
- recordButton.addEventListener('click', toggleRecording);
390
- });
 
391
  </script>
392
- <div class="audio-recorder-container">
393
- <button id="record-button" class="record-button">Start Recording</button>
394
- <audio id="audio-playback" controls style="display:block; margin-top:10px;"></audio>
395
- <input type="hidden" id="audio-data" name="audio-data">
396
- </div>
397
  <style>
398
- .audio-recorder-container {
399
- display: flex;
400
- flex-direction: column;
401
- align-items: center;
402
- padding: 20px;
403
- }
404
- .record-button {
405
- background-color: #f63366;
406
- color: white;
407
- border: none;
408
- padding: 10px 20px;
409
- border-radius: 5px;
410
- cursor: pointer;
411
- font-size: 16px;
412
- }
413
- .record-button.recording {
414
- background-color: #ff0000;
415
- animation: pulse 1.5s infinite;
416
- }
417
- @keyframes pulse {
418
- 0% { opacity: 1; }
419
- 50% { opacity: 0.7; }
420
- 100% { opacity: 1; }
421
- }
422
  </style>
423
  """
424
-
425
  return components.html(audio_recorder_html, height=150)
426
 
427
- # Function to display analysis results
428
  def display_analysis_results(transcribed_text):
429
- st.session_state.debug_info = st.session_state.get('debug_info', [])
430
- st.session_state.debug_info.append(f"Processing text: {transcribed_text[:50]}...")
431
- st.session_state.debug_info = st.session_state.debug_info[-100:] # Keep last 100 entries
432
-
433
- emotions_dict, top_emotion, emotion_map, sentiment = perform_emotion_detection(transcribed_text)
434
- is_sarcastic, sarcasm_score = perform_sarcasm_detection(transcribed_text)
435
 
436
- # Add results to debug info
437
- st.session_state.debug_info.append(f"Top emotion: {top_emotion}, Sentiment: {sentiment}")
438
- st.session_state.debug_info.append(f"Sarcasm: {is_sarcastic}, Score: {sarcasm_score:.3f}")
439
-
440
- st.header("Transcribed Text")
441
- st.text_area("Text", transcribed_text, height=150, disabled=True, help="The audio converted to text.")
442
-
443
- confidence_score = min(0.95, max(0.70, len(transcribed_text.split()) / 50))
444
- st.caption(f"Estimated transcription confidence: {confidence_score:.2f} (based on text length)")
445
-
446
- st.header("Analysis Results")
447
  col1, col2 = st.columns([1, 2])
448
-
449
  with col1:
450
  st.subheader("Sentiment")
451
  sentiment_icon = "👍" if sentiment == "POSITIVE" else "👎" if sentiment == "NEGATIVE" else "😐"
452
- st.markdown(f"{sentiment_icon} {sentiment.capitalize()}** (Based on {top_emotion})")
453
- st.info("Sentiment reflects the dominant emotion's tone.")
454
-
455
  st.subheader("Sarcasm")
456
  sarcasm_icon = "😏" if is_sarcastic else "😐"
457
- sarcasm_text = "Detected" if is_sarcastic else "Not Detected"
458
- st.markdown(f"{sarcasm_icon} {sarcasm_text}** (Score: {sarcasm_score:.3f})")
459
- st.info("Score indicates sarcasm confidence (0 to 1).")
460
-
461
  with col2:
462
  st.subheader("Emotions")
463
  if emotions_dict:
464
- st.markdown(
465
- f"*Dominant:* {emotion_map.get(top_emotion, '❓')} {top_emotion.capitalize()} (Score: {emotions_dict[top_emotion]:.3f})")
466
- sorted_emotions = sorted(emotions_dict.items(), key=lambda x: x[1], reverse=True)
467
- top_emotions = sorted_emotions[:8]
468
- emotions = [e[0] for e in top_emotions]
469
- scores = [e[1] for e in top_emotions]
470
- fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'},
471
- title="Top Emotions Distribution", color=emotions,
472
- color_discrete_sequence=px.colors.qualitative.Bold)
473
- fig.update_layout(yaxis_range=[0, 1], showlegend=False, title_font_size=14)
474
  st.plotly_chart(fig, use_container_width=True)
475
  else:
476
  st.write("No emotions detected.")
477
 
478
- with st.expander("Debug Information", expanded=False):
479
- st.write("Debugging information for troubleshooting:")
480
- for i, debug_line in enumerate(st.session_state.debug_info[-10:]):
481
- st.text(f"{i + 1}. {debug_line}")
482
- if emotions_dict:
483
- st.write("Raw emotion scores:")
484
- for emotion, score in sorted(emotions_dict.items(), key=lambda x: x[1], reverse=True):
485
- if score > 0.01: # Only show non-negligible scores
486
- st.text(f"{emotion}: {score:.4f}")
487
-
488
- with st.expander("Analysis Details", expanded=False):
489
- st.write("""
490
- *How this works:*
491
- 1. *Speech Recognition*: Audio transcribed using OpenAI Whisper (large-v3)
492
- 2. *Emotion Analysis*: DistilBERT model trained for six emotions
493
- 3. *Sentiment Analysis*: Derived from dominant emotion
494
- 4. *Sarcasm Detection*: RoBERTa model for irony detection
495
- *Accuracy depends on*:
496
- - Audio quality
497
- - Speech clarity
498
- - Background noise
499
- - Speech patterns
500
  """)
501
 
502
- # Process base64 audio data
503
- def process_base64_audio(base64_data):
504
- try:
505
- base64_binary = base64_data.split(',')[1]
506
- binary_data = base64.b64decode(base64_binary)
507
-
508
- temp_dir = tempfile.gettempdir()
509
- temp_file_path = os.path.join(temp_dir, f"recording_{int(time.time())}.wav")
510
-
511
- with open(temp_file_path, "wb") as f:
512
- f.write(binary_data)
513
-
514
- if not validate_audio(temp_file_path):
515
- return None
516
-
517
- return temp_file_path
518
- except Exception as e:
519
- st.error(f"Error processing audio data: {str(e)}")
520
- return None
521
-
522
- # Main App Logic
523
  def main():
524
  if 'debug_info' not in st.session_state:
525
  st.session_state.debug_info = []
526
 
527
- tab1, tab2 = st.tabs(["📁 Upload Audio", "🎙 Record Audio"])
528
-
529
  with tab1:
530
- st.header("Upload an Audio File")
531
- audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"],
532
- help="Upload an audio file for analysis")
533
-
534
  if audio_file:
535
  st.audio(audio_file.getvalue())
536
- st.caption("🎧 Uploaded Audio Playback")
537
-
538
- upload_button = st.button("Analyze Upload", key="analyze_upload")
539
-
540
- if upload_button:
541
- with st.spinner('Analyzing audio with advanced precision...'):
542
- temp_audio_path = process_uploaded_audio(audio_file)
543
- if temp_audio_path:
544
- main_text, alternatives = transcribe_audio(temp_audio_path, show_alternative=True)
545
-
546
- if main_text:
547
- if alternatives:
548
- with st.expander("Alternative transcriptions detected", expanded=False):
549
- for i, alt in enumerate(alternatives[:3], 1):
550
- st.write(f"{i}. {alt}")
551
-
552
- display_analysis_results(main_text)
553
- else:
554
- st.error("Could not transcribe the audio. Please try again with clearer audio.")
555
-
556
- if os.path.exists(temp_audio_path):
557
- os.remove(temp_audio_path)
558
-
559
  with tab2:
560
- st.header("Record Your Voice")
561
- st.write("Use the recorder below to analyze your speech in real-time.")
562
-
563
- st.subheader("Browser-Based Recorder")
564
- st.write("Click the button below to start/stop recording.")
565
-
566
  audio_data = custom_audio_recorder()
567
-
568
- if audio_data:
569
- analyze_rec_button = st.button("Analyze Recording", key="analyze_rec")
570
-
571
- if analyze_rec_button:
572
- with st.spinner("Processing your recording..."):
573
- temp_audio_path = process_base64_audio(audio_data)
574
-
575
- if temp_audio_path:
576
- transcribed_text = transcribe_audio(temp_audio_path)
577
-
578
- if transcribed_text:
579
- display_analysis_results(transcribed_text)
580
- else:
581
- st.error("Could not transcribe the audio. Please try speaking more clearly.")
582
-
583
- if os.path.exists(temp_audio_path):
584
- os.remove(temp_audio_path)
585
-
586
- st.subheader("Manual Text Input")
587
- st.write("If recording doesn't work, you can type your text here:")
588
-
589
- manual_text = st.text_area("Enter text to analyze:", placeholder="Type what you want to analyze...")
590
- analyze_text_button = st.button("Analyze Text", key="analyze_manual")
591
-
592
- if analyze_text_button and manual_text:
593
  display_analysis_results(manual_text)
594
 
595
- show_model_info()
596
-
597
  if __name__ == "__main__":
598
- main()
 
 
2
  import streamlit as st
3
  import tempfile
4
  import torch
5
+ import torchaudio
6
  import transformers
7
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
8
  import plotly.express as px
9
  import logging
10
  import warnings
11
  import whisper
12
  import base64
13
  import io
14
+ import asyncio
15
+ from concurrent.futures import ThreadPoolExecutor
16
  import streamlit.components.v1 as components
17
 
18
+ # Suppress warnings
19
+ logging.getLogger("torch").setLevel(logging.ERROR)
20
+ logging.getLogger("transformers").setLevel(logging.ERROR)
21
  warnings.filterwarnings("ignore")
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
+ # Device setup
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ st.write(f"Using device: {device}")
27
 
28
+ # Streamlit config
29
+ st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
30
+ st.title("🎙 Voice Sentiment Analysis")
31
+ st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.")
32
 
33
+ # Global model cache
34
  @st.cache_resource
35
+ def load_models():
36
+ whisper_model = whisper.load_model("base")
37
+
38
+ emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
39
+ emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
40
+ emotion_model = emotion_model.to(device).half()
41
+ emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
42
+ top_k=None, device=0 if torch.cuda.is_available() else -1)
43
+
44
+ sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
45
+ sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
46
+ sarcasm_model = sarcasm_model.to(device).half()
47
+ sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
48
+ device=0 if torch.cuda.is_available() else -1)
49
+
50
+ return whisper_model, emotion_classifier, sarcasm_classifier
51
+
52
+ whisper_model, emotion_classifier, sarcasm_classifier = load_models()
53
+
54
+ # Emotion detection
55
+ async def perform_emotion_detection(text):
56
+ if not text or len(text.strip()) < 3:
57
+ return {}, "neutral", {}, "NEUTRAL"
58
+
59
  try:
60
+ results = emotion_classifier(text)[0]
61
+ emotions_dict = {r['label']: r['score'] for r in results}
62
  filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
63
  top_emotion = max(filtered_emotions, key=filtered_emotions.get)
64
+
65
+ positive_emotions = ["joy"]
66
+ negative_emotions = ["anger", "disgust", "fear", "sadness"]
67
+ sentiment = ("POSITIVE" if top_emotion in positive_emotions else
68
+ "NEGATIVE" if top_emotion in negative_emotions else "NEUTRAL")
69
+
70
+ emotion_map = {"joy": "😊", "anger": "😡", "disgust": "🤢", "fear": "😨", "sadness": "😭", "surprise": "😲"}
71
  return emotions_dict, top_emotion, emotion_map, sentiment
72
  except Exception as e:
73
  st.error(f"Emotion detection failed: {str(e)}")
 
74
  return {}, "neutral", {}, "NEUTRAL"
75
 
76
+ # Sarcasm detection
77
+ async def perform_sarcasm_detection(text):
78
+ if not text or len(text.strip()) < 3:
79
+ return False, 0.0
80
+
81
  try:
82
  result = sarcasm_classifier(text)[0]
83
  is_sarcastic = result['label'] == "LABEL_1"
84
  sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
 
87
  st.error(f"Sarcasm detection failed: {str(e)}")
88
  return False, 0.0
89
 
90
+ # Audio validation
91
  def validate_audio(audio_path):
92
  try:
93
+ waveform, sample_rate = torchaudio.load(audio_path)
94
+ if waveform.abs().max() < 0.01:
95
+ st.warning("Audio volume too low.")
96
  return False
97
+ if waveform.shape[1] / sample_rate < 1:
98
+ st.warning("Audio too short.")
99
  return False
100
  return True
101
  except:
102
+ st.error("Invalid audio file.")
103
  return False
104
 
105
+ # Audio transcription
106
+ @st.cache_data
107
+ def transcribe_audio(audio_path):
108
  try:
109
+ waveform, sample_rate = torchaudio.load(audio_path)
110
+ if sample_rate != 16000:
111
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
112
+ waveform = resampler(waveform)
113
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
114
+ torchaudio.save(temp_file.name, waveform, 16000)
115
+ result = whisper_model.transcribe(temp_file.name, language="en")
116
+ os.remove(temp_file.name)
117
+ return result["text"].strip()
118
  except Exception as e:
119
  st.error(f"Transcription failed: {str(e)}")
120
+ return ""
121
 
122
+ # Process uploaded audio
123
  def process_uploaded_audio(audio_file):
124
  try:
 
 
125
  ext = audio_file.name.split('.')[-1].lower()
126
  if ext not in ['wav', 'mp3', 'ogg']:
127
+ st.error("Unsupported format. Use WAV, MP3, or OGG.")
128
  return None
129
+ with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file:
130
+ temp_file.write(audio_file.getvalue())
131
+ temp_file_path = temp_file.name
 
 
132
  if not validate_audio(temp_file_path):
133
+ os.remove(temp_file_path)
134
  return None
 
135
  return temp_file_path
136
  except Exception as e:
137
+ st.error(f"Error processing audio: {str(e)}")
138
  return None
139
 
140
+ # Process base64 audio
141
+ def process_base64_audio(base64_data):
142
+ try:
143
+ base64_binary = base64_data.split(',')[1]
144
+ binary_data = base64.b64decode(base64_binary)
145
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
146
+ temp_file.write(binary_data)
147
+ temp_file_path = temp_file.name
148
+ if not validate_audio(temp_file_path):
149
+ os.remove(temp_file_path)
150
+ return None
151
+ return temp_file_path
152
+ except Exception as e:
153
+ st.error(f"Error processing audio data: {str(e)}")
154
+ return None
155
 
156
+ # Custom audio recorder
157
  def custom_audio_recorder():
 
158
  audio_recorder_html = """
159
  <script>
160
+ let recorder, audioBlob, isRecording = false;
161
+ const recordButton = document.getElementById('record-button');
162
+ const audioPlayback = document.getElementById('audio-playback');
163
+ const audioData = document.getElementById('audio-data');
164
+
165
+ async function startRecording() {
166
+ try {
167
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
168
+ recorder = new MediaRecorder(stream);
169
+ const chunks = [];
170
+ recorder.ondataavailable = e => chunks.push(e.data);
171
+ recorder.onstop = () => {
172
+ audioBlob = new Blob(chunks, { type: 'audio/wav' });
173
+ audioPlayback.src = URL.createObjectURL(audioBlob);
174
+ const reader = new FileReader();
175
+ reader.readAsDataURL(audioBlob);
176
+ reader.onloadend = () => {
177
+ audioData.value = reader.result;
178
+ window.parent.postMessage({type: "streamlit:setComponentValue", value: reader.result}, "*");
179
+ };
180
+ stream.getTracks().forEach(track => track.stop());
181
+ };
182
+ recorder.start();
183
+ isRecording = true;
184
+ recordButton.textContent = 'Stop Recording';
185
+ recordButton.classList.add('recording');
186
+ } catch (e) {
187
+ alert('Recording failed: ' + e.message);
188
  }
189
  }
190
 
191
+ function stopRecording() {
192
+ recorder.stop();
193
+ isRecording = false;
194
+ recordButton.textContent = 'Start Recording';
195
+ recordButton.classList.remove('recording');
196
  }
197
 
198
+ document.getElementById('record-button').onclick = () => {
199
+ isRecording ? stopRecording() : startRecording();
200
+ };
201
  </script>
202
  <style>
203
+ .recorder-container { text-align: center; padding: 15px; }
204
+ .record-button { background: #ff4b4b; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer; }
205
+ .record-button.recording { background: #d32f2f; animation: pulse 1.5s infinite; }
206
+ @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.7; } 100% { opacity: 1; } }
207
+ audio { margin-top: 10px; width: 100%; }
208
  </style>
209
+ <div class="recorder-container">
210
+ <button id="record-button">Start Recording</button>
211
+ <audio id="audio-playback" controls></audio>
212
+ <input type="hidden" id="audio-data">
213
+ </div>
214
  """
 
215
  return components.html(audio_recorder_html, height=150)
216
 
217
+ # Display results
218
  def display_analysis_results(transcribed_text):
219
+ async def run_analyses():
220
+ emotion_task = perform_emotion_detection(transcribed_text)
221
+ sarcasm_task = perform_sarcasm_detection(transcribed_text)
222
+ return await asyncio.gather(emotion_task, sarcasm_task)
223
+
224
+ with st.spinner("Analyzing..."):
225
+ with ThreadPoolExecutor() as executor:
226
+ loop = asyncio.get_event_loop()
227
+ (emotions_dict, top_emotion, emotion_map, sentiment), (is_sarcastic, sarcasm_score) = loop.run_until_complete(run_analyses())
228
+
229
+ st.header("Results")
230
+ st.subheader("Transcribed Text")
231
+ st.text_area("Text", transcribed_text, height=100, disabled=True)
232
 
233
  col1, col2 = st.columns([1, 2])
 
234
  with col1:
235
  st.subheader("Sentiment")
236
  sentiment_icon = "👍" if sentiment == "POSITIVE" else "👎" if sentiment == "NEGATIVE" else "😐"
237
+ st.markdown(f"{sentiment_icon} **{sentiment}**")
238
+
 
239
  st.subheader("Sarcasm")
240
  sarcasm_icon = "😏" if is_sarcastic else "😐"
241
+ st.markdown(f"{sarcasm_icon} **{'Detected' if is_sarcastic else 'Not Detected'}** (Score: {sarcasm_score:.2f})")
242
+
 
 
243
  with col2:
244
  st.subheader("Emotions")
245
  if emotions_dict:
246
+ st.markdown(f"*Dominant:* {emotion_map.get(top_emotion, '❓')} **{top_emotion.capitalize()}** ({emotions_dict[top_emotion]:.2f})")
247
+ emotions = list(emotions_dict.keys())[:5]
248
+ scores = list(emotions_dict.values())[:5]
249
+ fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'}, color=emotions,
250
+ color_discrete_sequence=px.colors.qualitative.Set2)
251
+ fig.update_layout(yaxis_range=[0, 1], showlegend=False, height=300)
252
  st.plotly_chart(fig, use_container_width=True)
253
  else:
254
  st.write("No emotions detected.")
255
 
256
+ with st.expander("Details"):
257
+ st.markdown("""
258
+ - **Speech**: Whisper-base (fast, ~10-15% WER)
259
+ - **Emotions**: DistilBERT (joy, anger, etc.)
260
+ - **Sarcasm**: RoBERTa (irony detection)
261
+ - **Tips**: Clear audio, minimal noise
262
  """)
263
 
264
+ # Main app
265
  def main():
266
  if 'debug_info' not in st.session_state:
267
  st.session_state.debug_info = []
268
 
269
+ tab1, tab2, tab3 = st.tabs(["📁 Upload Audio", "🎙 Record Audio", "✍️ Text Input"])
270
+
271
  with tab1:
272
+ audio_file = st.file_uploader("Upload audio", type=["wav", "mp3", "ogg"])
273
  if audio_file:
274
  st.audio(audio_file.getvalue())
275
+ if st.button("Analyze", key="upload_analyze"):
276
+ progress = st.progress(0)
277
+ temp_path = process_uploaded_audio(audio_file)
278
+ if temp_path:
279
+ progress.progress(50)
280
+ text = transcribe_audio(temp_path)
281
+ if text:
282
+ progress.progress(100)
283
+ display_analysis_results(text)
284
+ else:
285
+ st.error("Transcription failed.")
286
+ os.remove(temp_path)
287
+ progress.empty()
288
+
289
  with tab2:
290
+ st.markdown("Record audio using your microphone.")
291
  audio_data = custom_audio_recorder()
292
+ if audio_data and st.button("Analyze", key="record_analyze"):
293
+ progress = st.progress(0)
294
+ temp_path = process_base64_audio(audio_data)
295
+ if temp_path:
296
+ progress.progress(50)
297
+ text = transcribe_audio(temp_path)
298
+ if text:
299
+ progress.progress(100)
300
+ display_analysis_results(text)
301
+ else:
302
+ st.error("Transcription failed.")
303
+ os.remove(temp_path)
304
+ progress.empty()
305
+
306
+ with tab3:
307
+ manual_text = st.text_area("Enter text:", placeholder="Type text to analyze...")
308
+ if st.button("Analyze", key="text_analyze") and manual_text:
309
  display_analysis_results(manual_text)
310
 
 
 
311
  if __name__ == "__main__":
312
+ main()
313
+ torch.cuda.empty_cache()