import os
import tempfile
import base64
import asyncio
import logging
import warnings

import torch
import torchaudio
import whisper
import plotly.express as px
import streamlit as st
import streamlit.components.v1 as components
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

logging.getLogger("torch").setLevel(logging.ERROR) |
|
logging.getLogger("transformers").setLevel(logging.ERROR) |
|
warnings.filterwarnings("ignore") |
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
st.write(f"Using device: {device}") |
|
|
|
|
|
# st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
st.title("🎤 Voice Sentiment Analysis")
st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.")
st.write(f"Using device: {device}")

@st.cache_resource
def load_models():
    whisper_model = whisper.load_model("base")

    # Run in half precision only on GPU; fp16 inference is not reliably supported on CPU.
    use_gpu = torch.cuda.is_available()
    pipeline_device = 0 if use_gpu else -1

    emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
    emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
    emotion_model = emotion_model.to(device)
    if use_gpu:
        emotion_model = emotion_model.half()
    emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
                                  top_k=None, device=pipeline_device)

    sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
    sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
    sarcasm_model = sarcasm_model.to(device)
    if use_gpu:
        sarcasm_model = sarcasm_model.half()
    sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
                                  device=pipeline_device)

    return whisper_model, emotion_classifier, sarcasm_classifier


whisper_model, emotion_classifier, sarcasm_classifier = load_models()

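# Shape of the classifier outputs consumed below (scores illustrative, not real results):
#   emotion_classifier("great job")  -> [[{'label': 'joy', 'score': 0.98}, {'label': 'sadness', 'score': 0.01}, ...]]
#       (top_k=None returns every label with its score)
#   sarcasm_classifier("great job")  -> [{'label': 'LABEL_0', 'score': 0.93}]
#       (only the top label; the exact label string depends on the model's config)
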
async def perform_emotion_detection(text):
    if not text or len(text.strip()) < 3:
        return {}, "neutral", {}, "NEUTRAL"

    try:
        results = emotion_classifier(text)[0]
        emotions_dict = {r['label']: r['score'] for r in results}
        # Keep emotions above a small threshold; fall back to the full dict so max() never sees an empty mapping.
        filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01} or emotions_dict
        top_emotion = max(filtered_emotions, key=filtered_emotions.get)

        # The emotion model's label set is typically sadness, joy, love, anger, fear, surprise;
        # "disgust" is kept in the negative list for safety.
        positive_emotions = ["joy", "love"]
        negative_emotions = ["anger", "disgust", "fear", "sadness"]
        sentiment = ("POSITIVE" if top_emotion in positive_emotions else
                     "NEGATIVE" if top_emotion in negative_emotions else "NEUTRAL")

        emotion_map = {"joy": "😊", "love": "🥰", "anger": "😡", "disgust": "🤢",
                       "fear": "😨", "sadness": "😢", "surprise": "😲"}
        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"

async def perform_sarcasm_detection(text):
    if not text or len(text.strip()) < 3:
        return False, 0.0

    try:
        result = sarcasm_classifier(text)[0]
        # Depending on the model config, the positive class may be reported as "LABEL_1" or "irony".
        is_sarcastic = result['label'] in ("LABEL_1", "irony")
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
        return is_sarcastic, sarcasm_score
    except Exception as e:
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0

def validate_audio(audio_path):
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        if waveform.abs().max() < 0.01:
            st.warning("Audio volume too low.")
            return False
        if waveform.shape[1] / sample_rate < 1:
            st.warning("Audio too short (less than one second).")
            return False
        return True
    except Exception:
        st.error("Invalid audio file.")
        return False

@st.cache_data
def transcribe_audio(audio_path):
    try:
        # Resample to 16 kHz (Whisper's expected rate) before transcription.
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            torchaudio.save(temp_file.name, waveform, 16000)
            result = whisper_model.transcribe(temp_file.name, language="en")
        os.remove(temp_file.name)
        return result["text"].strip()
    except Exception as e:
        st.error(f"Transcription failed: {str(e)}")
        return ""

def process_uploaded_audio(audio_file):
    try:
        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ['wav', 'mp3', 'ogg']:
            st.error("Unsupported format. Use WAV, MP3, or OGG.")
            return None
        with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file:
            temp_file.write(audio_file.getvalue())
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

def process_base64_audio(base64_data):
    try:
        base64_binary = base64_data.split(',')[1]
        binary_data = base64.b64decode(base64_binary)
        # Note: browser recordings are often WebM/Opus even when labelled audio/wav,
        # so torchaudio needs a backend that can decode the actual container.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(binary_data)
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio data: {str(e)}")
        return None

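# Data flow from the in-browser recorder (illustrative):
#   1. The component records audio and encodes it with FileReader.readAsDataURL,
#      yielding a data URL such as "data:audio/wav;base64,<payload>".
#   2. process_base64_audio() keeps everything after the first comma, decodes it,
#      and writes the bytes to a temporary file.
#   3. validate_audio() / transcribe_audio() then operate on that file.
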
def custom_audio_recorder():
    audio_recorder_html = """
    <style>
        .recorder-container { text-align: center; padding: 15px; }
        .record-button { background: #ff4b4b; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer; }
        .record-button.recording { background: #d32f2f; animation: pulse 1.5s infinite; }
        @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.7; } 100% { opacity: 1; } }
        audio { margin-top: 10px; width: 100%; }
    </style>
    <div class="recorder-container">
        <button id="record-button" class="record-button">Start Recording</button>
        <audio id="audio-playback" controls></audio>
        <input type="hidden" id="audio-data">
    </div>
    <!-- The script comes after the markup so getElementById finds the elements. -->
    <script>
        let recorder, audioBlob, isRecording = false;
        const recordButton = document.getElementById('record-button');
        const audioPlayback = document.getElementById('audio-playback');
        const audioData = document.getElementById('audio-data');

        async function startRecording() {
            try {
                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
                recorder = new MediaRecorder(stream);
                const chunks = [];
                recorder.ondataavailable = e => chunks.push(e.data);
                recorder.onstop = () => {
                    audioBlob = new Blob(chunks, { type: 'audio/wav' });
                    audioPlayback.src = URL.createObjectURL(audioBlob);
                    const reader = new FileReader();
                    reader.readAsDataURL(audioBlob);
                    reader.onloadend = () => {
                        audioData.value = reader.result;
                        window.parent.postMessage({type: "streamlit:setComponentValue", value: reader.result}, "*");
                    };
                    stream.getTracks().forEach(track => track.stop());
                };
                recorder.start();
                isRecording = true;
                recordButton.textContent = 'Stop Recording';
                recordButton.classList.add('recording');
            } catch (e) {
                alert('Recording failed: ' + e.message);
            }
        }

        function stopRecording() {
            recorder.stop();
            isRecording = false;
            recordButton.textContent = 'Start Recording';
            recordButton.classList.remove('recording');
        }

        recordButton.onclick = () => {
            isRecording ? stopRecording() : startRecording();
        };
    </script>
    """
    # components.html renders the widget but has no return channel of its own; the
    # streamlit:setComponentValue postMessage above only reaches Python from a true
    # bidirectional custom component, so the recorded value may come back as None here.
    return components.html(audio_recorder_html, height=150)

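# Alternative sketch (assumption: a recent Streamlit release that provides st.audio_input):
# the built-in widget returns the recording as an UploadedFile-like object, which could feed
# the same temp-file + transcribe_audio() path without any postMessage hand-off.
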
def display_analysis_results(transcribed_text):
    async def run_analyses():
        emotion_task = perform_emotion_detection(transcribed_text)
        sarcasm_task = perform_sarcasm_detection(transcribed_text)
        return await asyncio.gather(emotion_task, sarcasm_task)

    with st.spinner("Analyzing..."):
        # Streamlit scripts run without an active event loop, so start one just for the analyses.
        (emotions_dict, top_emotion, emotion_map, sentiment), (is_sarcastic, sarcasm_score) = asyncio.run(run_analyses())

    st.header("Results")
    st.subheader("Transcribed Text")
    st.text_area("Text", transcribed_text, height=100, disabled=True)

    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Sentiment")
        sentiment_icon = "😊" if sentiment == "POSITIVE" else "😔" if sentiment == "NEGATIVE" else "😐"
        st.markdown(f"{sentiment_icon} **{sentiment}**")

        st.subheader("Sarcasm")
        sarcasm_icon = "😏" if is_sarcastic else "🙂"
        st.markdown(f"{sarcasm_icon} **{'Detected' if is_sarcastic else 'Not Detected'}** (Score: {sarcasm_score:.2f})")

    with col2:
        st.subheader("Emotions")
        if emotions_dict:
            st.markdown(f"*Dominant:* {emotion_map.get(top_emotion, '❓')} **{top_emotion.capitalize()}** ({emotions_dict[top_emotion]:.2f})")
            # Chart the five highest-scoring emotions.
            top_emotions = sorted(emotions_dict.items(), key=lambda kv: kv[1], reverse=True)[:5]
            emotions = [label for label, _ in top_emotions]
            scores = [score for _, score in top_emotions]
            fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'}, color=emotions,
                         color_discrete_sequence=px.colors.qualitative.Set2)
            fig.update_layout(yaxis_range=[0, 1], showlegend=False, height=300)
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.write("No emotions detected.")

    with st.expander("Details"):
        st.markdown("""
        - **Speech**: Whisper-base (fast, ~10-15% WER)
        - **Emotions**: DistilBERT (joy, anger, etc.)
        - **Sarcasm**: RoBERTa (irony detection)
        - **Tips**: Clear audio, minimal noise
        """)

def main():
    if 'debug_info' not in st.session_state:
        st.session_state.debug_info = []

    tab1, tab2, tab3 = st.tabs(["📁 Upload Audio", "🎙️ Record Audio", "✏️ Text Input"])

    with tab1:
        audio_file = st.file_uploader("Upload audio", type=["wav", "mp3", "ogg"])
        if audio_file:
            st.audio(audio_file.getvalue())
            if st.button("Analyze", key="upload_analyze"):
                progress = st.progress(0)
                temp_path = process_uploaded_audio(audio_file)
                if temp_path:
                    progress.progress(50)
                    text = transcribe_audio(temp_path)
                    if text:
                        progress.progress(100)
                        display_analysis_results(text)
                    else:
                        st.error("Transcription failed.")
                    os.remove(temp_path)
                progress.empty()

    with tab2:
        st.markdown("Record audio using your microphone.")
        audio_data = custom_audio_recorder()
        if audio_data and st.button("Analyze", key="record_analyze"):
            progress = st.progress(0)
            temp_path = process_base64_audio(audio_data)
            if temp_path:
                progress.progress(50)
                text = transcribe_audio(temp_path)
                if text:
                    progress.progress(100)
                    display_analysis_results(text)
                else:
                    st.error("Transcription failed.")
                os.remove(temp_path)
            progress.empty()

    with tab3:
        manual_text = st.text_area("Enter text:", placeholder="Type text to analyze...")
        if st.button("Analyze", key="text_analyze") and manual_text:
            display_analysis_results(manual_text)


if __name__ == "__main__":
    main()
    torch.cuda.empty_cache()