import os
import streamlit as st
import tempfile
import torch
import torchaudio
import transformers
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import plotly.express as px
import logging
import warnings
import whisper
import base64
import io
import asyncio
from concurrent.futures import ThreadPoolExecutor
import streamlit.components.v1 as components
# Suppress warnings
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Streamlit config (st.set_page_config must be the first Streamlit call)
st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
st.title("πŸŽ™ Voice Sentiment Analysis")
st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.")
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.write(f"Using device: {device}")
# Global model cache
@st.cache_resource
def load_models():
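    """Load and cache the Whisper, emotion, and sarcasm models once per session."""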
whisper_model = whisper.load_model("base")
emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
    emotion_model = emotion_model.to(device)
    if device.type == "cuda":
        emotion_model = emotion_model.half()  # fp16 speeds up GPU inference; keep fp32 on CPU, where fp16 is poorly supported
emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
top_k=None, device=0 if torch.cuda.is_available() else -1)
sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
    sarcasm_model = sarcasm_model.to(device)
    if device.type == "cuda":
        sarcasm_model = sarcasm_model.half()  # same as above: only use half precision on GPU
sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
device=0 if torch.cuda.is_available() else -1)
return whisper_model, emotion_classifier, sarcasm_classifier
whisper_model, emotion_classifier, sarcasm_classifier = load_models()
# Emotion detection
async def perform_emotion_detection(text):
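    """Classify emotions in the text and derive an overall sentiment label from the top emotion."""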
if not text or len(text.strip()) < 3:
return {}, "neutral", {}, "NEUTRAL"
try:
results = emotion_classifier(text)[0]
emotions_dict = {r['label']: r['score'] for r in results}
filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
        if not filtered_emotions:
            filtered_emotions = emotions_dict
        top_emotion = max(filtered_emotions, key=filtered_emotions.get)
        # Labels of bhadresh-savani/distilbert-base-uncased-emotion: sadness, joy, love, anger, fear, surprise
        positive_emotions = ["joy", "love"]
        negative_emotions = ["anger", "fear", "sadness"]
        sentiment = ("POSITIVE" if top_emotion in positive_emotions else
                     "NEGATIVE" if top_emotion in negative_emotions else "NEUTRAL")
        emotion_map = {"joy": "😊", "love": "❀️", "anger": "😑", "fear": "😨", "sadness": "😭", "surprise": "😲"}
return emotions_dict, top_emotion, emotion_map, sentiment
except Exception as e:
st.error(f"Emotion detection failed: {str(e)}")
return {}, "neutral", {}, "NEUTRAL"
# Sarcasm detection
async def perform_sarcasm_detection(text):
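    """Detect irony/sarcasm in the text; returns (is_sarcastic, confidence)."""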
if not text or len(text.strip()) < 3:
return False, 0.0
try:
result = sarcasm_classifier(text)[0]
is_sarcastic = result['label'] == "LABEL_1"
sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
return is_sarcastic, sarcasm_score
except Exception as e:
st.error(f"Sarcasm detection failed: {str(e)}")
return False, 0.0
# Audio validation
def validate_audio(audio_path):
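    """Reject audio that is essentially silent or shorter than one second."""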
try:
waveform, sample_rate = torchaudio.load(audio_path)
if waveform.abs().max() < 0.01:
st.warning("Audio volume too low.")
return False
if waveform.shape[1] / sample_rate < 1:
st.warning("Audio too short.")
return False
return True
    except Exception:
st.error("Invalid audio file.")
return False
# Audio transcription
@st.cache_data
def transcribe_audio(audio_path):
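    """Resample the audio to 16 kHz mono-compatible WAV and transcribe it with Whisper."""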
try:
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
torchaudio.save(temp_file.name, waveform, 16000)
result = whisper_model.transcribe(temp_file.name, language="en")
os.remove(temp_file.name)
return result["text"].strip()
except Exception as e:
st.error(f"Transcription failed: {str(e)}")
return ""
# Process uploaded audio
def process_uploaded_audio(audio_file):
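    """Save an uploaded audio file to a temp path and validate it; returns the path or None."""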
try:
ext = audio_file.name.split('.')[-1].lower()
if ext not in ['wav', 'mp3', 'ogg']:
st.error("Unsupported format. Use WAV, MP3, or OGG.")
return None
with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file:
temp_file.write(audio_file.getvalue())
temp_file_path = temp_file.name
if not validate_audio(temp_file_path):
os.remove(temp_file_path)
return None
return temp_file_path
except Exception as e:
st.error(f"Error processing audio: {str(e)}")
return None
# Process base64 audio
def process_base64_audio(base64_data):
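    """Decode a base64 data URL from the recorder widget into a temp audio file; returns the path or None."""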
try:
base64_binary = base64_data.split(',')[1]
binary_data = base64.b64decode(base64_binary)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
temp_file.write(binary_data)
temp_file_path = temp_file.name
if not validate_audio(temp_file_path):
os.remove(temp_file_path)
return None
return temp_file_path
except Exception as e:
st.error(f"Error processing audio data: {str(e)}")
return None
# Custom audio recorder
def custom_audio_recorder():
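    """Render an in-browser recorder. Note: st.components.v1.html is one-way, so the recorded
    base64 payload is not passed back to Python unless this is converted to a bidirectional
    custom component (streamlit:setComponentValue only works for declared components)."""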
audio_recorder_html = """
<script>
let recorder, audioBlob, isRecording = false;
const recordButton = document.getElementById('record-button');
const audioPlayback = document.getElementById('audio-playback');
const audioData = document.getElementById('audio-data');
async function startRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
recorder = new MediaRecorder(stream);
const chunks = [];
recorder.ondataavailable = e => chunks.push(e.data);
recorder.onstop = () => {
audioBlob = new Blob(chunks, { type: 'audio/wav' });
audioPlayback.src = URL.createObjectURL(audioBlob);
const reader = new FileReader();
reader.readAsDataURL(audioBlob);
reader.onloadend = () => {
audioData.value = reader.result;
window.parent.postMessage({type: "streamlit:setComponentValue", value: reader.result}, "*");
};
stream.getTracks().forEach(track => track.stop());
};
recorder.start();
isRecording = true;
recordButton.textContent = 'Stop Recording';
recordButton.classList.add('recording');
} catch (e) {
alert('Recording failed: ' + e.message);
}
}
function stopRecording() {
recorder.stop();
isRecording = false;
recordButton.textContent = 'Start Recording';
recordButton.classList.remove('recording');
}
document.getElementById('record-button').onclick = () => {
isRecording ? stopRecording() : startRecording();
};
</script>
<style>
.recorder-container { text-align: center; padding: 15px; }
.record-button { background: #ff4b4b; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer; }
.record-button.recording { background: #d32f2f; animation: pulse 1.5s infinite; }
@keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.7; } 100% { opacity: 1; } }
audio { margin-top: 10px; width: 100%; }
</style>
<div class="recorder-container">
<button id="record-button">Start Recording</button>
<audio id="audio-playback" controls></audio>
<input type="hidden" id="audio-data">
</div>
"""
return components.html(audio_recorder_html, height=150)
# Display results
def display_analysis_results(transcribed_text):
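    """Run emotion and sarcasm analysis on the text and render the results."""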
async def run_analyses():
emotion_task = perform_emotion_detection(transcribed_text)
sarcasm_task = perform_sarcasm_detection(transcribed_text)
return await asyncio.gather(emotion_task, sarcasm_task)
with st.spinner("Analyzing..."):
with ThreadPoolExecutor() as executor:
loop = asyncio.get_event_loop()
(emotions_dict, top_emotion, emotion_map, sentiment), (is_sarcastic, sarcasm_score) = loop.run_until_complete(run_analyses())
st.header("Results")
st.subheader("Transcribed Text")
st.text_area("Text", transcribed_text, height=100, disabled=True)
col1, col2 = st.columns([1, 2])
with col1:
st.subheader("Sentiment")
sentiment_icon = "πŸ‘" if sentiment == "POSITIVE" else "πŸ‘Ž" if sentiment == "NEGATIVE" else "😐"
st.markdown(f"{sentiment_icon} **{sentiment}**")
st.subheader("Sarcasm")
sarcasm_icon = "😏" if is_sarcastic else "😐"
st.markdown(f"{sarcasm_icon} **{'Detected' if is_sarcastic else 'Not Detected'}** (Score: {sarcasm_score:.2f})")
with col2:
st.subheader("Emotions")
if emotions_dict:
st.markdown(f"*Dominant:* {emotion_map.get(top_emotion, '❓')} **{top_emotion.capitalize()}** ({emotions_dict[top_emotion]:.2f})")
emotions = list(emotions_dict.keys())[:5]
scores = list(emotions_dict.values())[:5]
fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'}, color=emotions,
color_discrete_sequence=px.colors.qualitative.Set2)
fig.update_layout(yaxis_range=[0, 1], showlegend=False, height=300)
st.plotly_chart(fig, use_container_width=True)
else:
st.write("No emotions detected.")
with st.expander("Details"):
st.markdown("""
- **Speech**: Whisper-base (fast, ~10-15% WER)
- **Emotions**: DistilBERT (joy, anger, etc.)
- **Sarcasm**: RoBERTa (irony detection)
- **Tips**: Clear audio, minimal noise
""")
# Main app
def main():
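    """Streamlit entry point: tabs for uploaded audio, recorded audio, and typed text."""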
if 'debug_info' not in st.session_state:
st.session_state.debug_info = []
tab1, tab2, tab3 = st.tabs(["πŸ“ Upload Audio", "πŸŽ™ Record Audio", "✍️ Text Input"])
with tab1:
audio_file = st.file_uploader("Upload audio", type=["wav", "mp3", "ogg"])
if audio_file:
st.audio(audio_file.getvalue())
if st.button("Analyze", key="upload_analyze"):
progress = st.progress(0)
temp_path = process_uploaded_audio(audio_file)
if temp_path:
progress.progress(50)
text = transcribe_audio(temp_path)
if text:
progress.progress(100)
display_analysis_results(text)
else:
st.error("Transcription failed.")
os.remove(temp_path)
progress.empty()
with tab2:
st.markdown("Record audio using your microphone.")
audio_data = custom_audio_recorder()
if audio_data and st.button("Analyze", key="record_analyze"):
progress = st.progress(0)
temp_path = process_base64_audio(audio_data)
if temp_path:
progress.progress(50)
text = transcribe_audio(temp_path)
if text:
progress.progress(100)
display_analysis_results(text)
else:
st.error("Transcription failed.")
os.remove(temp_path)
progress.empty()
with tab3:
manual_text = st.text_area("Enter text:", placeholder="Type text to analyze...")
if st.button("Analyze", key="text_analyze") and manual_text:
display_analysis_results(manual_text)
if __name__ == "__main__":
main()
torch.cuda.empty_cache()