import os
import tempfile
import base64
import asyncio
import logging
import warnings

import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import plotly.express as px
import whisper

# Try importing torchaudio, fall back to pydub
try:
    import torchaudio
    USE_TORCHAUDIO = True
except ImportError:
    from pydub import AudioSegment
    USE_TORCHAUDIO = False

# Suppress framework warnings
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Streamlit config (st.set_page_config must be the first Streamlit call in the script)
st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
st.title("🎙 Voice Sentiment Analysis")
st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.")
st.write(f"Using device: {device}")

if not USE_TORCHAUDIO:
    st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio")


# Global model cache
@st.cache_resource
def load_models():
    try:
        whisper_model = whisper.load_model("base")

        emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
        emotion_model = AutoModelForSequenceClassification.from_pretrained(
            "bhadresh-savani/distilbert-base-uncased-emotion"
        ).to(device)
        if device.type == "cuda":
            emotion_model = emotion_model.half()  # FP16 only on GPU; half precision is not supported for CPU inference
        emotion_classifier = pipeline(
            "text-classification",
            model=emotion_model,
            tokenizer=emotion_tokenizer,
            top_k=None,
            device=0 if torch.cuda.is_available() else -1,
        )

        sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
        sarcasm_model = AutoModelForSequenceClassification.from_pretrained(
            "cardiffnlp/twitter-roberta-base-irony"
        ).to(device)
        if device.type == "cuda":
            sarcasm_model = sarcasm_model.half()
        sarcasm_classifier = pipeline(
            "text-classification",
            model=sarcasm_model,
            tokenizer=sarcasm_tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )

        return whisper_model, emotion_classifier, sarcasm_classifier
    except Exception as e:
        st.error(f"Failed to load models: {str(e)}")
        raise


whisper_model, emotion_classifier, sarcasm_classifier = load_models()


# Emotion detection
async def perform_emotion_detection(text):
    if not text or len(text.strip()) < 3:
        return {}, "neutral", {}, "NEUTRAL"
    try:
        results = emotion_classifier(text)[0]
        emotions_dict = {r['label']: r['score'] for r in results}
        filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
        if not filtered_emotions:
            return emotions_dict, "neutral", {}, "NEUTRAL"
        top_emotion = max(filtered_emotions, key=filtered_emotions.get)
        # This emotion model predicts: sadness, joy, love, anger, fear, surprise
        positive_emotions = ["joy", "love"]
        negative_emotions = ["anger", "disgust", "fear", "sadness"]
        sentiment = ("POSITIVE" if top_emotion in positive_emotions
                     else "NEGATIVE" if top_emotion in negative_emotions
                     else "NEUTRAL")
        emotion_map = {"joy": "😊", "love": "❤️", "anger": "😡", "disgust": "🤢",
                       "fear": "😨", "sadness": "😭", "surprise": "😲"}
        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"


# Sarcasm detection
async def perform_sarcasm_detection(text):
    if not text or len(text.strip()) < 3:
        return False, 0.0
    try:
        result = sarcasm_classifier(text)[0]
        # The irony class; some revisions of the model config use generic label names
        is_sarcastic = result['label'] in ("LABEL_1", "irony")
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
        return is_sarcastic, sarcasm_score
    except Exception as e:
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0


# Audio validation
def validate_audio(audio_path):
    try:
        if USE_TORCHAUDIO:
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.abs().max() < 0.01:
                st.warning("Audio volume too low.")
                return False
            if waveform.shape[1] / sample_rate < 1:
                st.warning("Audio too short.")
                return False
        else:
            sound = AudioSegment.from_file(audio_path)
            if sound.dBFS < -55:
                st.warning("Audio volume too low.")
                return False
            if len(sound) < 1000:
                st.warning("Audio too short.")
                return False
        return True
    except Exception as e:
        st.error(f"Invalid audio file: {str(e)}")
        return False


# Audio transcription
@st.cache_data
def transcribe_audio(audio_path):
    try:
        if USE_TORCHAUDIO:
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono for Whisper
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                torchaudio.save(temp_file.name, waveform, 16000)
                result = whisper_model.transcribe(temp_file.name, language="en")
        else:
            sound = AudioSegment.from_file(audio_path)
            sound = sound.set_frame_rate(16000).set_channels(1)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                sound.export(temp_file.name, format="wav")
                result = whisper_model.transcribe(temp_file.name, language="en")
        os.remove(temp_file.name)
        return result["text"].strip()
    except Exception as e:
        st.error(f"Transcription failed: {str(e)}")
        return ""


# Process uploaded audio
def process_uploaded_audio(audio_file):
    try:
        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ['wav', 'mp3', 'ogg']:
            st.error("Unsupported format. Use WAV, MP3, or OGG.")
            return None
        with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file:
            temp_file.write(audio_file.getvalue())
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None


# Process base64 audio (expects a data URL such as "data:audio/...;base64,<payload>")
def process_base64_audio(base64_data):
    try:
        base64_binary = base64_data.split(',')[1]
        binary_data = base64.b64decode(base64_binary)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(binary_data)
            temp_file_path = temp_file.name
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing audio data: {str(e)}")
        return None


# Custom audio recorder
def custom_audio_recorder():
    # The markup below is a sketch; see the comment at the top of the HTML for its assumptions.
    audio_recorder_html = """
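    <!-- The recorder markup was not included in the original listing; this is a minimal, assumed
         sketch. It uses the browser MediaRecorder API to capture microphone audio and a FileReader
         to encode the clip as a base64 data URL, the format process_base64_audio() expects.
         Note: components.html() renders this widget but always returns None to Python, so passing
         the recording back to the app requires a bidirectional custom component or a library such
         as streamlit-webrtc. -->
    <div style="font-family: sans-serif;">
        <button id="record-btn">Start recording</button>
        <button id="stop-btn" disabled>Stop</button>
        <audio id="playback" controls style="display: block; margin-top: 8px;"></audio>
    </div>
    <script>
        let mediaRecorder;
        let chunks = [];
        const recordBtn = document.getElementById("record-btn");
        const stopBtn = document.getElementById("stop-btn");

        recordBtn.onclick = async () => {
            const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
            mediaRecorder = new MediaRecorder(stream);
            mediaRecorder.ondataavailable = (e) => chunks.push(e.data);
            mediaRecorder.onstop = () => {
                const blob = new Blob(chunks, { type: "audio/webm" });
                chunks = [];
                const reader = new FileReader();
                reader.onloadend = () => {
                    // reader.result is a "data:audio/webm;base64,..." string
                    document.getElementById("playback").src = reader.result;
                };
                reader.readAsDataURL(blob);
            };
            mediaRecorder.start();
            recordBtn.disabled = true;
            stopBtn.disabled = false;
        };

        stopBtn.onclick = () => {
            mediaRecorder.stop();
            recordBtn.disabled = false;
            stopBtn.disabled = true;
        };
    </script>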
    """
    return components.html(audio_recorder_html, height=150)


# Display results
def display_analysis_results(transcribed_text):
    async def run_analyses():
        emotion_task = perform_emotion_detection(transcribed_text)
        sarcasm_task = perform_sarcasm_detection(transcribed_text)
        return await asyncio.gather(emotion_task, sarcasm_task)

    with st.spinner("Analyzing..."):
        # asyncio.run creates a fresh event loop; Streamlit's script thread does not provide one
        (emotions_dict, top_emotion, emotion_map, sentiment), (is_sarcastic, sarcasm_score) = asyncio.run(run_analyses())

    st.header("Results")
    st.subheader("Transcribed Text")
    st.text_area("Text", transcribed_text, height=100, disabled=True)

    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Sentiment")
        sentiment_icon = "👍" if sentiment == "POSITIVE" else "👎" if sentiment == "NEGATIVE" else "😐"
        st.markdown(f"{sentiment_icon} **{sentiment}**")
        st.subheader("Sarcasm")
        sarcasm_icon = "😏" if is_sarcastic else "😐"
        st.markdown(f"{sarcasm_icon} **{'Detected' if is_sarcastic else 'Not Detected'}** (Score: {sarcasm_score:.2f})")
    with col2:
        st.subheader("Emotions")
        if emotions_dict:
            st.markdown(f"*Dominant:* {emotion_map.get(top_emotion, '❓')} **{top_emotion.capitalize()}** "
                        f"({emotions_dict[top_emotion]:.2f})")
            emotions = list(emotions_dict.keys())[:5]
            scores = list(emotions_dict.values())[:5]
            fig = px.bar(x=emotions, y=scores, labels={'x': 'Emotion', 'y': 'Score'},
                         color=emotions, color_discrete_sequence=px.colors.qualitative.Set2)
            fig.update_layout(yaxis_range=[0, 1], showlegend=False, height=300)
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.write("No emotions detected.")

    with st.expander("Details"):
        st.markdown("""
        - **Speech**: Whisper-base (fast, ~10-15% WER)
        - **Emotions**: DistilBERT (joy, anger, etc.)
        - **Sarcasm**: RoBERTa (irony detection)
        - **Tips**: Clear audio, minimal noise
        """)


# Main app
def main():
    if 'debug_info' not in st.session_state:
        st.session_state.debug_info = []

    tab1, tab2, tab3 = st.tabs(["📁 Upload Audio", "🎙 Record Audio", "✍️ Text Input"])

    with tab1:
        audio_file = st.file_uploader("Upload audio", type=["wav", "mp3", "ogg"])
        if audio_file:
            st.audio(audio_file.getvalue())
            if st.button("Analyze", key="upload_analyze"):
                progress = st.progress(0)
                temp_path = process_uploaded_audio(audio_file)
                if temp_path:
                    progress.progress(50)
                    text = transcribe_audio(temp_path)
                    if text:
                        progress.progress(100)
                        display_analysis_results(text)
                    else:
                        st.error("Transcription failed.")
                    if os.path.exists(temp_path):
                        os.remove(temp_path)
                progress.empty()

    with tab2:
        st.markdown("Record audio using your microphone.")
        # components.html() returns None, so audio_data stays empty unless the recorder is swapped
        # for a bidirectional component that can send the base64 audio back to Python.
        audio_data = custom_audio_recorder()
        if audio_data and st.button("Analyze", key="record_analyze"):
            progress = st.progress(0)
            temp_path = process_base64_audio(audio_data)
            if temp_path:
                progress.progress(50)
                text = transcribe_audio(temp_path)
                if text:
                    progress.progress(100)
                    display_analysis_results(text)
                else:
                    st.error("Transcription failed.")
                if os.path.exists(temp_path):
                    os.remove(temp_path)
            progress.empty()

    with tab3:
        manual_text = st.text_area("Enter text:", placeholder="Type text to analyze...")
        if st.button("Analyze", key="text_analyze") and manual_text:
            display_analysis_results(manual_text)


if __name__ == "__main__":
    main()
    torch.cuda.empty_cache()
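# Assumed usage (not part of the original listing): save this script as app.py and launch it with
# Streamlit. Dependencies include streamlit, torch, transformers, openai-whisper, plotly, and
# either torchaudio or pydub; Whisper (and pydub) also need the ffmpeg binary on the PATH.
#
#   pip install streamlit torch transformers openai-whisper plotly torchaudio
#   streamlit run app.py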