|
|
|
import os |
|
import streamlit as st |
|
import tempfile |
|
import torch |
|
import transformers |
|
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer |
|
import plotly.express as px |
|
import logging |
|
import warnings |
|
import whisper |
|
from pydub import AudioSegment |
|
import time |
|
import base64 |
|
import io |
|
import streamlit.components.v1 as components |
|
|
|
|
|
logging.getLogger("torch").setLevel(logging.CRITICAL) |
|
logging.getLogger("transformers").setLevel(logging.CRITICAL) |
|
warnings.filterwarnings("ignore") |
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"Using device: {device}") |
|
|
|
|
|
st.set_page_config(layout="wide", page_title="Voice Based Sentiment Analysis") |
|
|
|
|
|
st.title("🎙️ Voice Based Sentiment Analysis")
|
st.write("Detect emotions, sentiment, and sarcasm from your voice with state-of-the-art accuracy using OpenAI Whisper.") |
|
|
|
|
|
@st.cache_resource |
|
def get_emotion_classifier(): |
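    """Load and cache the GoEmotions text-classification pipeline (top_k=None returns scores for all 28 labels)."""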
|
tokenizer = AutoTokenizer.from_pretrained("SamLowe/roberta-base-go_emotions", use_fast=True) |
|
model = AutoModelForSequenceClassification.from_pretrained("SamLowe/roberta-base-go_emotions") |
|
model = model.to(device) |
|
return pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=-1 if device.type == "cpu" else 0) |
|
|
|
def perform_emotion_detection(text): |
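    """Run the emotion classifier on `text` and map the dominant emotion to a coarse sentiment.

    Returns (emotions_dict, top_emotion, emotion_map, sentiment); on failure returns
    empty/unknown placeholders so callers can degrade gracefully.
    """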
|
try: |
|
emotion_classifier = get_emotion_classifier() |
|
emotion_results = emotion_classifier(text)[0] |
|
|
|
        emotion_map = {
            "admiration": "🤩", "amusement": "😂", "anger": "😡", "annoyance": "😠",
            "approval": "👍", "caring": "🤗", "confusion": "😕", "curiosity": "🧐",
            "desire": "😍", "disappointment": "😞", "disapproval": "👎", "disgust": "🤢",
            "embarrassment": "😳", "excitement": "🤩", "fear": "😨", "gratitude": "🙏",
            "grief": "😢", "joy": "😄", "love": "❤️", "nervousness": "😰",
            "optimism": "🙂", "pride": "😌", "realization": "💡", "relief": "😌",
            "remorse": "😔", "sadness": "😭", "surprise": "😲", "neutral": "😐"
        }
|
|
|
positive_emotions = ["admiration", "amusement", "approval", "caring", "desire", |
|
"excitement", "gratitude", "joy", "love", "optimism", "pride", "relief"] |
|
negative_emotions = ["anger", "annoyance", "disappointment", "disapproval", "disgust", |
|
"embarrassment", "fear", "grief", "nervousness", "remorse", "sadness"] |
|
neutral_emotions = ["confusion", "curiosity", "realization", "surprise", "neutral"] |
|
|
|
emotions_dict = {result['label']: result['score'] for result in emotion_results} |
|
top_emotion = max(emotions_dict, key=emotions_dict.get) |
|
|
|
if top_emotion in positive_emotions: |
|
sentiment = "POSITIVE" |
|
elif top_emotion in negative_emotions: |
|
sentiment = "NEGATIVE" |
|
else: |
|
sentiment = "NEUTRAL" |
|
|
|
return emotions_dict, top_emotion, emotion_map, sentiment |
|
except Exception as e: |
|
st.error(f"Emotion detection failed: {str(e)}") |
|
return {}, "unknown", {}, "UNKNOWN" |
|
|
|
|
|
@st.cache_resource |
|
def get_sarcasm_classifier(): |
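    """Load and cache the irony/sarcasm text-classification pipeline."""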
|
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony", use_fast=True) |
|
model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony") |
|
model = model.to(device) |
|
return pipeline("text-classification", model=model, tokenizer=tokenizer, device=-1 if device.type == "cpu" else 0) |
|
|
|
def perform_sarcasm_detection(text): |
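    """Classify `text` as sarcastic or not; returns (is_sarcastic, sarcasm_probability)."""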
|
try: |
|
sarcasm_classifier = get_sarcasm_classifier() |
|
result = sarcasm_classifier(text)[0] |
|
        # The positive class is reported as "LABEL_1" or "irony" depending on the model config
        is_sarcastic = result['label'] in ("LABEL_1", "irony")
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
|
return is_sarcastic, sarcasm_score |
|
except Exception as e: |
|
st.error(f"Sarcasm detection failed: {str(e)}") |
|
return False, 0.0 |
|
|
|
|
|
def validate_audio(audio_path): |
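    """Reject audio that is too quiet (below -50 dBFS) or shorter than one second."""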
|
try: |
|
sound = AudioSegment.from_file(audio_path) |
|
if sound.dBFS < -50: |
|
st.warning("Audio volume is too low. Please record or upload a louder audio.") |
|
return False |
|
if len(sound) < 1000: |
|
st.warning("Audio is too short. Please record a longer audio.") |
|
return False |
|
return True |
|
    except Exception:
|
st.error("Invalid or corrupted audio file.") |
|
return False |
|
|
|
|
|
@st.cache_resource |
|
def load_whisper_model(): |
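    """Load and cache the Whisper large-v3 checkpoint (~1.5B parameters; downloaded on first use)."""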
|
|
|
model = whisper.load_model("large-v3") |
|
return model |
|
|
|
def transcribe_audio(audio_path, show_alternative=False): |
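    """Convert the audio to 16 kHz mono WAV, transcribe it with Whisper, and return the text.

    When `show_alternative` is True, a (text, alternatives) tuple is returned; Whisper does not
    produce alternative transcriptions here, so the list is always empty.
    """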
|
try: |
|
st.write(f"Processing audio file: {audio_path}") |
|
sound = AudioSegment.from_file(audio_path) |
|
st.write(f"Audio duration: {len(sound)/1000:.2f}s, Sample rate: {sound.frame_rate}, Channels: {sound.channels}") |
|
|
|
|
|
temp_wav_path = os.path.join(tempfile.gettempdir(), "temp_converted.wav") |
|
sound = sound.set_frame_rate(16000) |
|
sound = sound.set_channels(1) |
|
sound.export(temp_wav_path, format="wav") |
|
|
|
|
|
model = load_whisper_model() |
|
|
|
|
|
result = model.transcribe(temp_wav_path, language="en") |
|
main_text = result["text"].strip() |
|
|
|
|
|
if os.path.exists(temp_wav_path): |
|
os.remove(temp_wav_path) |
|
|
|
|
|
if show_alternative: |
|
return main_text, [] |
|
return main_text |
|
except Exception as e: |
|
st.error(f"Transcription failed: {str(e)}") |
|
return "", [] if show_alternative else "" |
|
|
|
|
|
def process_uploaded_audio(audio_file): |
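    """Persist an uploaded file to a temporary path and validate it; returns the path or None."""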
|
if not audio_file: |
|
return None |
|
|
|
try: |
|
temp_dir = tempfile.gettempdir() |
|
temp_file_path = os.path.join(temp_dir, f"uploaded_audio_{int(time.time())}.wav") |
|
|
|
with open(temp_file_path, "wb") as f: |
|
f.write(audio_file.getvalue()) |
|
|
|
if not validate_audio(temp_file_path): |
|
return None |
|
|
|
return temp_file_path |
|
except Exception as e: |
|
st.error(f"Error processing uploaded audio: {str(e)}") |
|
return None |
|
|
|
|
|
def show_model_info(): |
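    """Render sidebar tabs describing the emotion, sarcasm, and speech-recognition models."""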
|
    st.sidebar.header("🧠 About the Models")
|
|
|
model_tabs = st.sidebar.tabs(["Emotion", "Sarcasm", "Speech"]) |
|
|
|
with model_tabs[0]: |
|
st.markdown(""" |
|
**Emotion Model**: SamLowe/roberta-base-go_emotions |
|
- Fine-tuned on GoEmotions dataset (58k Reddit comments, 27 emotions) |
|
- Architecture: RoBERTa base |
|
- Micro-F1: 0.46 |
|
[π Model Hub](https://huggingface.co/SamLowe/roberta-base-go_emotions) |
|
""") |
|
|
|
with model_tabs[1]: |
|
st.markdown(""" |
|
**Sarcasm Model**: cardiffnlp/twitter-roberta-base-irony |
|
- Trained on SemEval-2018 Task 3 (Twitter irony dataset) |
|
- Architecture: RoBERTa base |
|
- F1-score: 0.705 |
|
[π Model Hub](https://huggingface.co/cardiffnlp/twitter-roberta-base-irony) |
|
""") |
|
|
|
with model_tabs[2]: |
|
st.markdown(""" |
|
**Speech Recognition**: OpenAI Whisper (large-v3) |
|
- State-of-the-art model for speech-to-text |
|
- Accuracy: ~5-10% WER on clean English audio |
|
- Robust to noise, accents, and varied conditions |
|
- Runs locally, no internet required |
|
**Tips**: Use good mic, reduce noise, speak clearly |
|
[π Model Details](https://github.com/openai/whisper) |
|
""") |
|
|
|
|
|
def custom_audio_recorder(): |
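    """Embed a MediaRecorder-based recorder via components.html.

    Note: components.html is a static component; whether the postMessage
    "streamlit:setComponentValue" hack actually returns the base64 audio to Python
    depends on the Streamlit version, so the manual text input in the Record tab
    serves as a fallback.
    """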
|
audio_recorder_html = """ |
|
<script> |
|
var audioRecorder = { |
|
audioBlobs: [], |
|
mediaRecorder: null, |
|
streamBeingCaptured: null, |
|
start: function() { |
|
if (!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia)) { |
|
return Promise.reject(new Error('mediaDevices API or getUserMedia method is not supported in this browser.')); |
|
} |
|
else { |
|
return navigator.mediaDevices.getUserMedia({ audio: true }) |
|
.then(stream => { |
|
audioRecorder.streamBeingCaptured = stream; |
|
audioRecorder.mediaRecorder = new MediaRecorder(stream); |
|
audioRecorder.audioBlobs = []; |
|
|
|
audioRecorder.mediaRecorder.addEventListener("dataavailable", event => { |
|
audioRecorder.audioBlobs.push(event.data); |
|
}); |
|
|
|
audioRecorder.mediaRecorder.start(); |
|
}); |
|
} |
|
}, |
|
stop: function() { |
|
return new Promise(resolve => { |
|
let mimeType = audioRecorder.mediaRecorder.mimeType; |
|
|
|
audioRecorder.mediaRecorder.addEventListener("stop", () => { |
|
let audioBlob = new Blob(audioRecorder.audioBlobs, { type: mimeType }); |
|
resolve(audioBlob); |
|
}); |
|
|
|
audioRecorder.mediaRecorder.stop(); |
|
|
|
audioRecorder.stopStream(); |
|
audioRecorder.resetRecordingProperties(); |
|
}); |
|
}, |
|
stopStream: function() { |
|
audioRecorder.streamBeingCaptured.getTracks() |
|
.forEach(track => track.stop()); |
|
}, |
|
resetRecordingProperties: function() { |
|
audioRecorder.mediaRecorder = null; |
|
audioRecorder.streamBeingCaptured = null; |
|
} |
|
} |
|
|
|
var isRecording = false; |
|
var recordButton = document.getElementById('record-button'); |
|
var audioElement = document.getElementById('audio-playback'); |
|
var audioData = document.getElementById('audio-data'); |
|
|
|
function toggleRecording() { |
|
if (!isRecording) { |
|
audioRecorder.start() |
|
.then(() => { |
|
isRecording = true; |
|
recordButton.textContent = 'Stop Recording'; |
|
recordButton.classList.add('recording'); |
|
}) |
|
.catch(error => { |
|
alert('Error starting recording: ' + error.message); |
|
}); |
|
} else { |
|
audioRecorder.stop() |
|
.then(audioBlob => { |
|
const audioUrl = URL.createObjectURL(audioBlob); |
|
audioElement.src = audioUrl; |
|
|
|
const reader = new FileReader(); |
|
reader.readAsDataURL(audioBlob); |
|
reader.onloadend = function() { |
|
const base64data = reader.result; |
|
audioData.value = base64data; |
|
const streamlitMessage = {type: "streamlit:setComponentValue", value: base64data}; |
|
window.parent.postMessage(streamlitMessage, "*"); |
|
} |
|
|
|
isRecording = false; |
|
recordButton.textContent = 'Start Recording'; |
|
recordButton.classList.remove('recording'); |
|
}); |
|
} |
|
} |
|
|
|
document.addEventListener('DOMContentLoaded', function() { |
|
recordButton = document.getElementById('record-button'); |
|
audioElement = document.getElementById('audio-playback'); |
|
audioData = document.getElementById('audio-data'); |
|
|
|
recordButton.addEventListener('click', toggleRecording); |
|
}); |
|
</script> |
|
|
|
<div class="audio-recorder-container"> |
|
<button id="record-button" class="record-button">Start Recording</button> |
|
<audio id="audio-playback" controls style="display:block; margin-top:10px;"></audio> |
|
<input type="hidden" id="audio-data" name="audio-data"> |
|
</div> |
|
|
|
<style> |
|
.audio-recorder-container { |
|
display: flex; |
|
flex-direction: column; |
|
align-items: center; |
|
padding: 20px; |
|
} |
|
.record-button { |
|
background-color: #f63366; |
|
color: white; |
|
border: none; |
|
padding: 10px 20px; |
|
border-radius: 5px; |
|
cursor: pointer; |
|
font-size: 16px; |
|
} |
|
.record-button.recording { |
|
background-color: #ff0000; |
|
animation: pulse 1.5s infinite; |
|
} |
|
@keyframes pulse { |
|
0% { opacity: 1; } |
|
50% { opacity: 0.7; } |
|
100% { opacity: 1; } |
|
} |
|
</style> |
|
""" |
|
|
|
return components.html(audio_recorder_html, height=150) |
|
|
|
|
|
def display_analysis_results(transcribed_text): |
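    """Show the transcript plus emotion, sentiment, and sarcasm results with a Plotly bar chart."""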
|
emotions_dict, top_emotion, emotion_map, sentiment = perform_emotion_detection(transcribed_text) |
|
is_sarcastic, sarcasm_score = perform_sarcasm_detection(transcribed_text) |
|
|
|
st.header("Transcribed Text") |
|
st.text_area("Text", transcribed_text, height=150, disabled=True, help="The audio converted to text.") |
|
|
|
    # Heuristic only: Whisper does not return a single overall confidence value,
    # so this scales a pseudo-confidence with the transcript's word count.
    confidence_score = min(0.95, max(0.70, len(transcribed_text.split()) / 50))
    st.caption(f"Estimated transcription confidence (heuristic): {confidence_score:.2f}")
|
|
|
st.header("Analysis Results") |
|
col1, col2 = st.columns([1, 2]) |
|
|
|
with col1: |
|
st.subheader("Sentiment") |
|
        sentiment_icon = "😊" if sentiment == "POSITIVE" else "😞" if sentiment == "NEGATIVE" else "😐"
|
st.markdown(f"**{sentiment_icon} {sentiment.capitalize()}** (Based on {top_emotion})") |
|
st.info("Sentiment reflects the dominant emotion's tone.") |
|
|
|
st.subheader("Sarcasm") |
|
        sarcasm_icon = "😏" if is_sarcastic else "🙂"
|
sarcasm_text = "Detected" if is_sarcastic else "Not Detected" |
|
st.markdown(f"**{sarcasm_icon} {sarcasm_text}** (Score: {sarcasm_score:.3f})") |
|
st.info("Score indicates sarcasm confidence (0 to 1).") |
|
|
|
with col2: |
|
st.subheader("Emotions") |
|
if emotions_dict: |
|
st.markdown(f"**Dominant:** {emotion_map.get(top_emotion, 'β')} {top_emotion.capitalize()} (Score: {emotions_dict[top_emotion]:.3f})") |
|
sorted_emotions = sorted(emotions_dict.items(), key=lambda x: x[1], reverse=True) |
|
top_emotions = sorted_emotions[:8] |
|
emotions = [e[0] for e in top_emotions] |
|
scores = [e[1] for e in top_emotions] |
|
fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'}, |
|
title="Top Emotions Distribution", color=emotions, |
|
color_discrete_sequence=px.colors.qualitative.Bold) |
|
fig.update_layout(yaxis_range=[0, 1], showlegend=False, title_font_size=14) |
|
st.plotly_chart(fig, use_container_width=True) |
|
else: |
|
st.write("No emotions detected.") |
|
|
|
with st.expander("Analysis Details", expanded=False): |
|
st.write(""" |
|
**How this works:** |
|
1. **Speech Recognition**: Audio transcribed using OpenAI Whisper (large-v3) |
|
2. **Emotion Analysis**: RoBERTa model trained on GoEmotions (27 emotions) |
|
3. **Sentiment Analysis**: Derived from dominant emotion |
|
4. **Sarcasm Detection**: RoBERTa model for irony detection |
|
**Accuracy depends on**: |
|
- Audio quality |
|
- Speech clarity |
|
- Background noise |
|
- Speech patterns |
|
""") |
|
|
|
|
|
def process_base64_audio(base64_data): |
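    """Decode a base64 data URL from the recorder, save it to a temporary file, and validate it."""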
|
try: |
|
        # Data URLs look like "data:audio/webm;base64,<payload>"; strip the header if present
        base64_binary = base64_data.split(',')[1] if ',' in base64_data else base64_data
|
binary_data = base64.b64decode(base64_binary) |
|
|
|
temp_dir = tempfile.gettempdir() |
|
temp_file_path = os.path.join(temp_dir, f"recording_{int(time.time())}.wav") |
|
|
|
with open(temp_file_path, "wb") as f: |
|
f.write(binary_data) |
|
|
|
if not validate_audio(temp_file_path): |
|
return None |
|
|
|
return temp_file_path |
|
except Exception as e: |
|
st.error(f"Error processing audio data: {str(e)}") |
|
return None |
|
|
|
|
|
def main(): |
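    """Build the two-tab UI (file upload and browser recording) and wire up the analysis flow."""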
|
    tab1, tab2 = st.tabs(["📁 Upload Audio", "🎙️ Record Audio"])
|
|
|
with tab1: |
|
st.header("Upload an Audio File") |
|
audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"], |
|
help="Upload an audio file for analysis") |
|
|
|
if audio_file: |
|
st.audio(audio_file.getvalue()) |
|
            st.caption("🎧 Uploaded Audio Playback")
|
|
|
upload_button = st.button("Analyze Upload", key="analyze_upload") |
|
|
|
if upload_button: |
|
            with st.spinner('Analyzing audio...'):
|
temp_audio_path = process_uploaded_audio(audio_file) |
|
if temp_audio_path: |
|
main_text, alternatives = transcribe_audio(temp_audio_path, show_alternative=True) |
|
|
|
if main_text: |
|
if alternatives: |
|
with st.expander("Alternative transcriptions detected", expanded=False): |
|
for i, alt in enumerate(alternatives[:3], 1): |
|
st.write(f"{i}. {alt}") |
|
|
|
display_analysis_results(main_text) |
|
else: |
|
st.error("Could not transcribe the audio. Please try again with clearer audio.") |
|
|
|
if os.path.exists(temp_audio_path): |
|
os.remove(temp_audio_path) |
|
|
|
with tab2: |
|
st.header("Record Your Voice") |
|
st.write("Use the recorder below to analyze your speech in real-time.") |
|
|
|
st.subheader("Browser-Based Recorder") |
|
st.write("Click the button below to start/stop recording.") |
|
|
|
audio_data = custom_audio_recorder() |
|
|
|
if audio_data: |
|
analyze_rec_button = st.button("Analyze Recording", key="analyze_rec") |
|
|
|
if analyze_rec_button: |
|
with st.spinner("Processing your recording..."): |
|
temp_audio_path = process_base64_audio(audio_data) |
|
|
|
if temp_audio_path: |
|
transcribed_text = transcribe_audio(temp_audio_path) |
|
|
|
if transcribed_text: |
|
display_analysis_results(transcribed_text) |
|
else: |
|
st.error("Could not transcribe the audio. Please try speaking more clearly.") |
|
|
|
if os.path.exists(temp_audio_path): |
|
os.remove(temp_audio_path) |
|
|
|
st.subheader("Manual Text Input") |
|
st.write("If recording doesn't work, you can type your text here:") |
|
|
|
manual_text = st.text_area("Enter text to analyze:", placeholder="Type what you want to analyze...") |
|
analyze_text_button = st.button("Analyze Text", key="analyze_manual") |
|
|
|
if analyze_text_button and manual_text: |
|
display_analysis_results(manual_text) |
|
|
|
show_model_info() |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|