import os import streamlit as st import tempfile import torch import transformers from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer import plotly.express as px import logging import warnings import whisper import base64 import io import asyncio from concurrent.futures import ThreadPoolExecutor import streamlit.components.v1 as components # Try importing torchaudio, fallback to pydub try: import torchaudio USE_TORCHAUDIO = True except ImportError: from pydub import AudioSegment USE_TORCHAUDIO = False st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio") # Suppress warnings and set logging logging.getLogger("torch").setLevel(logging.ERROR) logging.getLogger("transformers").setLevel(logging.ERROR) warnings.filterwarnings("ignore") os.environ["TOKENIZERS_PARALLELISM"] = "false" # Streamlit config st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis") st.title("🎙 Voice Sentiment Analysis") st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from voice or text.") # Global model cache @st.cache_resource def load_models(): try: # Load Whisper model with CPU optimization whisper_model = whisper.load_model("base") # Load emotion detection model emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion") emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion") emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer, top_k=None, device=-1) # CPU only # Load sarcasm detection model sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony") sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony") sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer, device=-1) # CPU only return whisper_model, emotion_classifier, sarcasm_classifier except Exception as e: st.error(f"Failed to load models: {str(e)}") raise whisper_model, emotion_classifier, sarcasm_classifier = load_models() # Emotion detection async def perform_emotion_detection(text): if not text or len(text.strip()) < 3: return {}, "neutral", {}, "NEUTRAL" try: results = emotion_classifier(text)[0] emotions_dict = {r['label']: r['score'] for r in results} filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01} top_emotion = max(filtered_emotions, key=filtered_emotions.get, default="neutral") positive_emotions = ["joy"] negative_emotions = ["anger", "disgust", "fear", "sadness"] sentiment = ("POSITIVE" if top_emotion in positive_emotions else "NEGATIVE" if top_emotion in negative_emotions else "NEUTRAL") emotion_map = {"joy": "😊", "anger": "😡", "disgust": "🤢", "fear": "😨", "sadness": "😭", "surprise": "😲"} return emotions_dict, top_emotion, emotion_map, sentiment except Exception as e: st.error(f"Emotion detection failed: {str(e)}") return {}, "neutral", {}, "NEUTRAL" # Sarcasm detection async def perform_sarcasm_detection(text): if not text or len(text.strip()) < 3: return False, 0.0 try: result = sarcasm_classifier(text)[0] is_sarcastic = result['label'] == "LABEL_1" sarcasm_score = result['score'] if is_sarcastic else 1 - result['score'] return is_sarcastic, sarcasm_score except Exception as e: st.error(f"Sarcasm detection failed: {str(e)}") return False, 0.0 # Audio validation def validate_audio(audio_path): try: if USE_TORCHAUDIO: waveform, sample_rate = torchaudio.load(audio_path) if waveform.abs().max() < 0.01: st.warning("Audio volume too low.") return False if waveform.shape[1] / sample_rate < 1: st.warning("Audio too short.") return False else: sound = AudioSegment.from_file(audio_path) if sound.dBFS < -55: st.warning("Audio volume too low.") return False if len(sound) < 1000: st.warning("Audio too short.") return False return True except Exception as e: st.error(f"Invalid audio file: {str(e)}") return False # Audio transcription @st.cache_data def transcribe_audio(audio_path): try: if USE_TORCHAUDIO: waveform, sample_rate = torchaudio.load(audio_path) if sample_rate != 16000: resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: torchaudio.save(temp_file.name, waveform, 16000) result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6) else: sound = AudioSegment.from_file(audio_path) sound = sound.set_frame_rate(16000).set_channels(1) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: sound.export(temp_file.name, format="wav") result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6) os.remove(temp_file.name) return result["text"].strip() except Exception as e: st.error(f"Transcription failed: {str(e)}") return "" # Process uploaded audio def process_uploaded_audio(audio_file): try: ext = audio_file.name.split('.')[-1].lower() if ext not in ['wav', 'mp3', 'ogg']: st.error("Unsupported format. Use WAV, MP3, or OGG.") return None with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as temp_file: temp_file.write(audio_file.getvalue()) temp_file_path = temp_file.name if not validate_audio(temp_file_path): os.remove(temp_file_path) return None return temp_file_path except Exception as e: st.error(f"Error processing audio: {str(e)}") return None # Process base64 audio def process_base64_audio(base64_data): try: if not base64_data.startswith("data:audio"): st.error("Invalid audio data.") return None base64_binary = base64_data.split(',')[1] binary_data = base64.b64decode(base64_binary) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: temp_file.write(binary_data) temp_file_path = temp_file.name if not validate_audio(temp_file_path): os.remove(temp_file_path) return None return temp_file_path except Exception as e: st.error(f"Error processing audio data: {str(e)}") return None # Custom audio recorder def custom_audio_recorder(): audio_recorder_html = """