import os
import streamlit as st
import tempfile
import torch
import transformers
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import plotly.express as px
import logging
import warnings
import whisper
import base64
import io
import asyncio
from concurrent.futures import ThreadPoolExecutor
import streamlit.components.v1 as components

# Try importing torchaudio, fall back to pydub
try:
    import torchaudio
    USE_TORCHAUDIO = True
except ImportError:
    from pydub import AudioSegment
    USE_TORCHAUDIO = False
    st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio")

# Suppress warnings and set logging
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Streamlit config
st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
st.title("🎙 Voice Sentiment Analysis")
st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.")

# Global model cache
@st.cache_resource
def load_models():
    try:
        # Load Whisper model with CPU optimization
        whisper_model = whisper.load_model("base")

        # Load emotion detection model
        emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
        emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
        emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
                                      top_k=None, device=-1)  # CPU only

        # Load sarcasm detection model
        sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
        sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
        sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
                                      device=-1)  # CPU only

        return whisper_model, emotion_classifier, sarcasm_classifier
    except Exception as e:
        st.error(f"Failed to load models: {str(e)}")
        raise

whisper_model, emotion_classifier, sarcasm_classifier = load_models()
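
# Optional sanity check: a minimal sketch for exercising the two classifiers outside the
# Streamlit UI. The label sets noted in the comments are assumptions taken from the
# respective model cards, and VSA_SMOKE_TEST is a hypothetical opt-in environment variable
# introduced here purely for illustration.
if os.environ.get("VSA_SMOKE_TEST"):
    demo = emotion_classifier("I am so happy with how this turned out!")[0]
    print({d["label"]: round(d["score"], 3) for d in demo})   # labels: sadness, joy, love, anger, fear, surprise
    print(sarcasm_classifier("Oh great, another Monday.")[0])  # LABEL_0 = not ironic, LABEL_1 = ironic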
# Emotion detection
async def perform_emotion_detection(text):
    if not text or len(text.strip()) < 3:
        return {}, "neutral", {}, "NEUTRAL"
    try:
        results = emotion_classifier(text)[0]
        emotions_dict = {r['label']: r['score'] for r in results}
        filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
        top_emotion = max(filtered_emotions, key=filtered_emotions.get, default="neutral")
        # The emotion model's documented label set is sadness, joy, love, anger, fear, surprise.
        positive_emotions = ["joy", "love"]
        negative_emotions = ["anger", "disgust", "fear", "sadness"]
        sentiment = ("POSITIVE" if top_emotion in positive_emotions
                     else "NEGATIVE" if top_emotion in negative_emotions
                     else "NEUTRAL")
        emotion_map = {"joy": "😊", "love": "🥰", "anger": "😡", "disgust": "🤢",
                       "fear": "😨", "sadness": "😭", "surprise": "😲"}
        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"

# Sarcasm detection
async def perform_sarcasm_detection(text):
    if not text or len(text.strip()) < 3:
        return False, 0.0
    try:
        result = sarcasm_classifier(text)[0]
        is_sarcastic = result['label'] == "LABEL_1"  # LABEL_1 = irony
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
        return is_sarcastic, sarcasm_score
    except Exception as e:
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0

# Audio validation
def validate_audio(audio_path):
    try:
        if USE_TORCHAUDIO:
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.abs().max() < 0.01:
                st.warning("Audio volume too low.")
                return False
            if waveform.shape[1] / sample_rate < 1:  # shorter than one second
                st.warning("Audio too short.")
                return False
        else:
            sound = AudioSegment.from_file(audio_path)
            if sound.dBFS < -55:
                st.warning("Audio volume too low.")
                return False
            if len(sound) < 1000:  # pydub lengths are in milliseconds
                st.warning("Audio too short.")
                return False
        return True
    except Exception as e:
        st.error(f"Invalid audio file: {str(e)}")
        return False

# Audio transcription
@st.cache_data
def transcribe_audio(audio_path):
    try:
        if USE_TORCHAUDIO:
            waveform, sample_rate = torchaudio.load(audio_path)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                torchaudio.save(temp_file.name, waveform, 16000)
        else:
            sound = AudioSegment.from_file(audio_path)
            sound = sound.set_frame_rate(16000).set_channels(1)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                sound.export(temp_file.name, format="wav")
        result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6)
        os.remove(temp_file.name)
        return result["text"].strip()
    except Exception as e:
        st.error(f"Transcription failed: {str(e)}")
        return ""

# Process uploaded audio
def process_uploaded_audio(audio_file):
    try:
        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ['wav', 'mp3', 'ogg']:
            st.error("Unsupported format. Use WAV, MP3, or OGG.")
            return None
        with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file:
            temp_file.write(audio_file.getvalue())
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

# Process base64 audio
def process_base64_audio(base64_data):
    try:
        if not base64_data.startswith("data:audio"):
            st.error("Invalid audio data.")
            return None
        base64_binary = base64_data.split(',')[1]
        binary_data = base64.b64decode(base64_binary)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(binary_data)
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio data: {str(e)}")
        return None
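
# Note on the recorder below: components.html() renders static markup and returns None to
# Python, so the recorder HTML (omitted in this snippet) cannot hand audio back on its own.
# Delivering a "data:audio/...;base64," payload to process_base64_audio() would need a
# bidirectional custom component (components.declare_component) or, on recent Streamlit
# releases, the built-in st.audio_input widget; both are assumptions about the deployment
# environment rather than part of this script.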
""" return components.html(audio_recorder_html, height=150) # Display results def display_analysis_results(transcribed_text): async def run_analyses(): emotion_task = perform_emotion_detection(transcribed_text) sarcasm_task = perform_sarcasm_detection(transcribed_text) return await asyncio.gather(emotion_task, sarcasm_task) with st.spinner("Analyzing..."): with ThreadPoolExecutor() as executor: loop = asyncio.get_event_loop() (emotions_dict, top_emotion, emotion_map, sentiment), (is_sarcastic, sarcasm_score) = loop.run_until_complete(run_analyses()) st.header("Results") st.subheader("Transcribed Text") st.text_area("Text", transcribed_text, height=100, disabled=True) col1, col2 = st.columns([1, 2]) with col1: st.subheader("Sentiment") sentiment_icon = "👍" if sentiment == "POSITIVE" else "👎" if sentiment == "NEGATIVE" else "😐" st.markdown(f"{sentiment_icon} **{sentiment}**") st.subheader("Sarcasm") sarcasm_icon = "😏" if is_sarcastic else "😐" st.markdown(f"{sarcasm_icon} **{'Detected' if is_sarcastic else 'Not Detected'}** (Score: {sarcasm_score:.2f})") with col2: st.subheader("Emotions") if emotions_dict: st.markdown(f"*Dominant:* {emotion_map.get(top_emotion, '❓')} **{top_emotion.capitalize()}** ({emotions_dict[top_emotion]:.2f})") emotions = list(emotions_dict.keys())[:5] scores = list(emotions_dict.values())[:5] fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'}, color=emotions, color_discrete_sequence=px.colors.qualitative.Set2) fig.update_layout(yaxis_range=[0, 1], showlegend=False, height=300) st.plotly_chart(fig, use_container_width=True) else: st.write("No emotions detected.") with st.expander("Details"): st.markdown(""" - **Speech**: Whisper-base (fast, ~10-15% WER) - **Emotions**: DistilBERT (joy, anger, etc.) - **Sarcasm**: RoBERTa (irony detection) - **Tips**: Clear audio, minimal noise """) # Main app def main(): if 'debug_info' not in st.session_state: st.session_state.debug_info = [] tab1, tab2, tab3 = st.tabs(["📁 Upload Audio", "🎙 Record Audio", "✍️ Text Input"]) with tab1: audio_file = st.file_uploader("Upload audio", type=["wav", "mp3", "ogg"]) if audio_file: st.audio(audio_file.getvalue()) if st.button("Analyze", key="upload_analyze"): progress = st.progress(0) temp_path = process_uploaded_audio(audio_file) if temp_path: progress.progress(50) text = transcribe_audio(temp_path) if text: progress.progress(100) display_analysis_results(text) else: st.error("Transcription failed.") if os.path.exists(temp_path): os.remove(temp_path) progress.empty() with tab2: st.markdown("Record audio using your microphone.") audio_data = custom_audio_recorder() if audio_data and st.button("Analyze", key="record_analyze"): progress = st.progress(0) temp_path = process_base64_audio(audio_data) if temp_path: progress.progress(50) text = transcribe_audio(temp_path) if text: progress.progress(100) display_analysis_results(text) else: st.error("Transcription failed.") if os.path.exists(temp_path): os.remove(temp_path) progress.empty() with tab3: manual_text = st.text_area("Enter text:", placeholder="Type text to analyze...") if st.button("Analyze", key="text_analyze") and manual_text: display_analysis_results(manual_text) if __name__ == "__main__": main()