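"""Voice Based Sentiment Analysis: a Streamlit app.

Pipeline: audio upload or in-browser recording -> OpenAI Whisper
transcription -> DistilBERT emotion classification -> RoBERTa irony
(sarcasm) detection -> Plotly visualization of the results.

Run with: streamlit run <this-file>.py
"""
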
import os
import tempfile
import time
import base64
import logging
import warnings

import numpy as np
import plotly.express as px
import streamlit as st
import streamlit.components.v1 as components
import torch
import whisper
from pydub import AudioSegment
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

# Quiet noisy library logs and warnings so they do not clutter the app UI.
logging.getLogger("torch").setLevel(logging.CRITICAL)
logging.getLogger("transformers").setLevel(logging.CRITICAL)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Fail fast if the NumPy/PyTorch interop is broken (e.g. an ABI mismatch).
try:
    test_array = np.array([1, 2, 3])
    torch.from_numpy(test_array)
except Exception as e:
    st.error(f"NumPy is not available or incompatible with PyTorch: {str(e)}. "
             "Ensure 'numpy' is in requirements.txt and reinstall dependencies.")
    st.stop()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

st.set_page_config(layout="wide", page_title="Voice Based Sentiment Analysis")

st.title("🎤 Voice Based Sentiment Analysis")
st.write("Detect emotions, sentiment, and sarcasm from your voice, transcribed with OpenAI Whisper.")


@st.cache_resource
def get_emotion_classifier():
    """Load and cache the DistilBERT emotion classifier."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "bhadresh-savani/distilbert-base-uncased-emotion", use_fast=True)
        model = AutoModelForSequenceClassification.from_pretrained(
            "bhadresh-savani/distilbert-base-uncased-emotion").to(device)
        if torch.cuda.is_available():
            model = model.half()  # fp16 halves memory use and speeds up GPU inference
        classifier = pipeline("text-classification",
                              model=model,
                              tokenizer=tokenizer,
                              top_k=None,  # return a score for every label, not just the best
                              device=0 if torch.cuda.is_available() else -1)
        return classifier
    except Exception as e:
        st.error(f"Failed to load emotion model: {str(e)}")
        return None


def perform_emotion_detection(text):
    """Classify `text` into emotions and derive an overall sentiment label."""
    try:
        if not text or len(text.strip()) < 3:
            return {}, "neutral", {}, "NEUTRAL"
        emotion_classifier = get_emotion_classifier()
        if not emotion_classifier:
            return {}, "neutral", {}, "NEUTRAL"
        emotion_results = emotion_classifier(text)[0]
        # Labels emitted by bhadresh-savani/distilbert-base-uncased-emotion:
        # sadness, joy, love, anger, fear, surprise.
        emotion_map = {
            "joy": "😊", "love": "🥰", "anger": "😡", "fear": "😨",
            "sadness": "😢", "surprise": "😲"
        }
        positive_emotions = ["joy", "love"]
        negative_emotions = ["anger", "fear", "sadness"]
        # Surprise is treated as neutral (the fall-through case below).
        emotions_dict = {result['label']: result['score'] for result in emotion_results}
        # Ignore near-zero scores so a noisy tail cannot become the top emotion.
        filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
        if not filtered_emotions:
            filtered_emotions = emotions_dict
        top_emotion = max(filtered_emotions, key=filtered_emotions.get)
        if top_emotion in positive_emotions:
            sentiment = "POSITIVE"
        elif top_emotion in negative_emotions:
            sentiment = "NEGATIVE"
        else:
            sentiment = "NEUTRAL"
        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"


@st.cache_resource
def get_sarcasm_classifier():
    """Load and cache the RoBERTa irony classifier used for sarcasm detection."""
    try:
        tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony", use_fast=True)
        model = AutoModelForSequenceClassification.from_pretrained(
            "cardiffnlp/twitter-roberta-base-irony").to(device)
        if torch.cuda.is_available():
            model = model.half()
        classifier = pipeline("text-classification", model=model, tokenizer=tokenizer,
                              device=0 if torch.cuda.is_available() else -1)
        return classifier
    except Exception as e:
        st.error(f"Failed to load sarcasm model: {str(e)}")
        return None


def perform_sarcasm_detection(text):
    """Return (is_sarcastic, confidence) for `text` using the irony model."""
    try:
        if not text or len(text.strip()) < 3:
            return False, 0.0
        sarcasm_classifier = get_sarcasm_classifier()
        if not sarcasm_classifier:
            return False, 0.0
        result = sarcasm_classifier(text)[0]
        # The positive class is named "irony" in current model configs; older
        # revisions exposed the raw id "LABEL_1". Accept either spelling.
        is_sarcastic = result['label'] in ("irony", "LABEL_1")
        # Normalize so the score always refers to the "sarcastic" class.
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
        return is_sarcastic, sarcasm_score
    except Exception as e:
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0


def validate_audio(audio_path):
    """Reject clips that are too quiet or too short to transcribe reliably."""
    try:
        sound = AudioSegment.from_file(audio_path)
        if sound.dBFS < -55:
            st.warning("Audio volume is too low (below -55 dBFS).")
            return False
        if len(sound) < 1000:  # pydub lengths are in milliseconds
            st.warning("Audio is too short (under 1 second).")
            return False
        return True
    except Exception as e:
        st.error(f"Invalid audio file: {str(e)}")
        return False


@st.cache_resource
def load_whisper_model():
    """Load and cache the Whisper 'base' speech-to-text model."""
    try:
        model = whisper.load_model("base").to(device)
        return model
    except Exception as e:
        st.error(f"Failed to load Whisper model: {str(e)}")
        return None


def transcribe_audio(audio_path):
    """Convert the clip to 16 kHz mono WAV and transcribe it with Whisper."""
    temp_wav_path = None
    try:
        # Whisper expects 16 kHz mono input.
        sound = AudioSegment.from_file(audio_path).set_frame_rate(16000).set_channels(1)
        temp_wav_path = os.path.join(tempfile.gettempdir(), f"temp_{int(time.time())}.wav")
        sound.export(temp_wav_path, format="wav")
        model = load_whisper_model()
        if not model:
            return ""
        result = model.transcribe(temp_wav_path, language="en", fp16=torch.cuda.is_available())
        return result["text"].strip()
    except Exception as e:
        st.error(f"Transcription failed: {str(e)}")
        return ""
    finally:
        if temp_wav_path and os.path.exists(temp_wav_path):
            os.remove(temp_wav_path)


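# The upload/record helpers below stage audio in a temp file and return its
# path. Cleanup is left to the caller in main(), because deleting the file
# here (e.g. in a finally block) would remove it before transcription runs.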
def process_uploaded_audio(audio_file):
    """Persist an uploaded file to disk, validate it, and return its path."""
    if not audio_file:
        return None
    temp_file_path = None
    try:
        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ['wav', 'mp3', 'ogg']:
            st.error("Unsupported audio format. Use WAV, MP3, or OGG.")
            return None
        temp_file_path = os.path.join(tempfile.gettempdir(), f"uploaded_{int(time.time())}.{ext}")
        with open(temp_file_path, "wb") as f:
            f.write(audio_file.getvalue())
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing uploaded audio: {str(e)}")
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        return None


def show_model_info():
    """Sidebar summary of the three models used by the app."""
    st.sidebar.header("🧠 About the Models")
    with st.sidebar.expander("Model Details"):
        st.markdown("""
- **Emotion**: DistilBERT (bhadresh-savani/distilbert-base-uncased-emotion)
- **Sarcasm**: RoBERTa (cardiffnlp/twitter-roberta-base-irony)
- **Speech**: OpenAI Whisper (base)
""")


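# NOTE: the recorder below is an informal pattern, not an official Streamlit
# API. It embeds a MediaRecorder widget via components.html and posts the
# finished recording back as a base64 data URL using a
# "streamlit:setComponentValue" message; components.html is not a true
# bidirectional custom component, so the returned value may be None on some
# Streamlit versions.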
def custom_audio_recorder():
    st.warning("Recording requires microphone access and a modern browser.")
    audio_recorder_html = """
    <script>
    let recorder, stream;
    async function startRecording() {
      try {
        stream = await navigator.mediaDevices.getUserMedia({ audio: true });
        recorder = new MediaRecorder(stream);
        const chunks = [];
        recorder.ondataavailable = e => chunks.push(e.data);
        recorder.onstop = () => {
          // The actual encoding is browser-dependent (usually WebM/Opus);
          // the server side sniffs the real format via ffmpeg when decoding.
          const blob = new Blob(chunks, { type: 'audio/wav' });
          const reader = new FileReader();
          reader.onloadend = () => {
            // Hand the base64 data URL back to Streamlit.
            window.parent.postMessage({type: "streamlit:setComponentValue", value: reader.result}, "*");
          };
          reader.readAsDataURL(blob);
          stream.getTracks().forEach(track => track.stop());  // release the mic
        };
        recorder.start();
        document.getElementById('record-btn').textContent = 'Stop Recording';
      } catch (e) { alert('Recording failed: ' + e.message); }
    }
    function stopRecording() {
      recorder.stop();
      document.getElementById('record-btn').textContent = 'Start Recording';
    }
    function toggleRecording() {
      if (!recorder || recorder.state === 'inactive') startRecording();
      else stopRecording();
    }
    </script>
    <button id="record-btn" onclick="toggleRecording()">Start Recording</button>
    <style>
    #record-btn {
      background-color: #f63366;
      color: white;
      border: none;
      padding: 10px 20px;
      border-radius: 5px;
      cursor: pointer;
    }
    #record-btn:hover {
      background-color: #ff0000;
    }
    </style>
    """
    return components.html(audio_recorder_html, height=100)


def display_analysis_results(transcribed_text):
    """Run emotion and sarcasm analysis on the text and render the results."""
    emotions_dict, top_emotion, emotion_map, sentiment = perform_emotion_detection(transcribed_text)
    is_sarcastic, sarcasm_score = perform_sarcasm_detection(transcribed_text)
    st.header("Analysis Results")
    st.text_area("Transcribed Text", transcribed_text, height=100, disabled=True)
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Sentiment")
        sentiment_icon = "😊" if sentiment == "POSITIVE" else "😞" if sentiment == "NEGATIVE" else "😐"
        st.markdown(f"{sentiment_icon} {sentiment} (Based on {top_emotion})")
        st.subheader("Sarcasm")
        sarcasm_icon = "😏" if is_sarcastic else "🙂"
        st.markdown(f"{sarcasm_icon} {'Detected' if is_sarcastic else 'Not Detected'} (Score: {sarcasm_score:.2f})")
    with col2:
        st.subheader("Emotions")
        if emotions_dict:
            st.markdown(f"**Dominant:** {emotion_map.get(top_emotion, '❓')} {top_emotion.capitalize()} "
                        f"(Score: {emotions_dict[top_emotion]:.3f})")
            fig = px.bar(x=list(emotions_dict.keys()), y=list(emotions_dict.values()),
                         labels={'x': 'Emotion', 'y': 'Score'}, title="Emotion Distribution")
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.write("No emotions detected.")


def process_base64_audio(base64_data):
    """Decode a base64 data URL from the recorder, validate it, return the path."""
    temp_file_path = None
    try:
        # Strip the "data:audio/...;base64," prefix before decoding.
        audio_bytes = base64.b64decode(base64_data.split(',')[1])
        temp_file_path = os.path.join(tempfile.gettempdir(), f"rec_{int(time.time())}.wav")
        with open(temp_file_path, "wb") as f:
            f.write(audio_bytes)
        if not validate_audio(temp_file_path):
            os.remove(temp_file_path)
            return None
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing recorded audio: {str(e)}")
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        return None


def main():
    tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])
    with tab1:
        st.header("Upload an Audio File")
        audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
        if audio_file:
            st.audio(audio_file.getvalue())
            if st.button("Analyze Upload", key="analyze_upload"):
                with st.spinner("Analyzing audio..."):
                    temp_audio_path = process_uploaded_audio(audio_file)
                    if temp_audio_path:
                        transcribed_text = transcribe_audio(temp_audio_path)
                        os.remove(temp_audio_path)  # caller-side cleanup (see note above)
                        if transcribed_text:
                            display_analysis_results(transcribed_text)
                        else:
                            st.error("Could not transcribe audio. Try clearer audio.")
    with tab2:
        st.header("Record Your Voice")
        st.subheader("Browser-Based Recorder")
        audio_data = custom_audio_recorder()
        if audio_data and st.button("Analyze Recording", key="analyze_rec"):
            with st.spinner("Processing recording..."):
                temp_audio_path = process_base64_audio(audio_data)
                if temp_audio_path:
                    transcribed_text = transcribe_audio(temp_audio_path)
                    os.remove(temp_audio_path)
                    if transcribed_text:
                        display_analysis_results(transcribed_text)
                    else:
                        st.error("Could not transcribe audio. Speak clearly.")
        st.subheader("Manual Text Input")
        manual_text = st.text_area("Enter text to analyze:", placeholder="Type your text...")
        if st.button("Analyze Text", key="analyze_manual") and manual_text:
            display_analysis_results(manual_text)
    show_model_info()


if __name__ == "__main__":
    main()
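
# Assumed dependencies (usual PyPI names; this file does not pin versions):
# streamlit, torch, transformers, plotly, openai-whisper, pydub, numpy.
# pydub additionally needs an ffmpeg binary on PATH to decode MP3/OGG input.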