MaroofTechSorcerer's picture
Update app.py
1cec378 verified
raw
history blame
13.3 kB
import os
import streamlit as st
import tempfile
import torch
import transformers
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import plotly.express as px
import logging
import warnings
import whisper
from pydub import AudioSegment
import time
import base64
import io
import streamlit.components.v1 as components
import numpy as np
# Keep the console quiet: mute the noisy library loggers and Python warnings.
for _noisy_logger in ("torch", "transformers"):
    logging.getLogger(_noisy_logger).setLevel(logging.CRITICAL)
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Sanity check: NumPy must be usable and its arrays convertible by PyTorch.
# On failure the error is shown in the app and the script run is stopped.
try:
    test_array = np.array([1, 2, 3])
    torch.from_numpy(test_array)
except Exception as e:
    st.error(f"NumPy is not available or incompatible with PyTorch: {str(e)}. Ensure 'numpy' is in requirements.txt and reinstall dependencies.")
    st.stop()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Page layout.
st.set_page_config(layout="wide", page_title="Voice Based Sentiment Analysis")

# Page heading and short description.
st.title("πŸŽ™ Voice Based Sentiment Analysis")
st.write("Detect emotions, sentiment, and sarcasm from your voice with optimized speed and accuracy using OpenAI Whisper.")
# Emotion Detection Function
@st.cache_resource
def get_emotion_classifier():
    """Build the emotion text-classification pipeline.

    The result is wrapped in ``st.cache_resource``. The model is moved to
    the module-level ``device`` and converted to fp16 when CUDA is available.

    Returns:
        The ``transformers`` pipeline (configured with ``top_k=None``), or
        ``None`` if loading failed; the failure is reported via ``st.error``.
    """
    model_name = "bhadresh-savani/distilbert-base-uncased-emotion"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        model = model.to(device)
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            model = model.half()  # Use fp16 on GPU
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            top_k=None,
            device=0 if use_gpu else -1,
        )
    except Exception as e:
        st.error(f"Failed to load emotion model: {str(e)}")
        return None
def perform_emotion_detection(text):
    """Classify the emotions in ``text`` and derive an overall sentiment.

    Args:
        text: The text to analyse. Empty text, or text shorter than three
            characters after stripping, is treated as neutral without
            running the model.

    Returns:
        A 4-tuple ``(emotions_dict, top_emotion, emotion_map, sentiment)``:
        ``emotions_dict`` maps each label returned by the classifier to its
        score, ``top_emotion`` is the highest-scoring label, ``emotion_map``
        maps labels to display icons, and ``sentiment`` is ``"POSITIVE"``,
        ``"NEGATIVE"`` or ``"NEUTRAL"``. When the text is too short, the
        classifier is unavailable, or an error occurs, the result is
        ``({}, "neutral", {}, "NEUTRAL")``; errors are reported via
        ``st.error``.
    """
    try:
        if not text or len(text.strip()) < 3:
            return {}, "neutral", {}, "NEUTRAL"
        emotion_classifier = get_emotion_classifier()
        if not emotion_classifier:
            return {}, "neutral", {}, "NEUTRAL"
        emotion_results = emotion_classifier(text)[0]
        emotion_map = {
            "joy": "😊", "anger": "😑", "disgust": "🀒", "fear": "😨",
            "sadness": "😭", "surprise": "😲",
            # The emotion model's label set is expected to include "love"
            # (see its model card); without an entry it would render with
            # the fallback icon.
            "love": "❤️",
        }
        # "love" is counted as positive; previously it fell through to
        # NEUTRAL because it was in neither list.
        positive_emotions = {"joy", "love"}
        negative_emotions = {"anger", "disgust", "fear", "sadness"}
        emotions_dict = {result['label']: result['score'] for result in emotion_results}
        # The highest-scoring label wins. (Filtering out scores <= 0.01
        # first cannot change which label is the maximum, so it is omitted.)
        top_emotion = max(emotions_dict, key=emotions_dict.get)
        if top_emotion in positive_emotions:
            sentiment = "POSITIVE"
        elif top_emotion in negative_emotions:
            sentiment = "NEGATIVE"
        else:
            # Anything else (e.g. "surprise") carries no polarity.
            sentiment = "NEUTRAL"
        return emotions_dict, top_emotion, emotion_map, sentiment
    except Exception as e:
        st.error(f"Emotion detection failed: {str(e)}")
        return {}, "neutral", {}, "NEUTRAL"
# Sarcasm Detection Function
@st.cache_resource
def get_sarcasm_classifier():
    """Build the irony/sarcasm text-classification pipeline.

    The result is wrapped in ``st.cache_resource``. The model is moved to
    the module-level ``device`` and converted to fp16 when CUDA is available.

    Returns:
        The ``transformers`` pipeline, or ``None`` if loading failed; the
        failure is reported via ``st.error``.
    """
    model_name = "cardiffnlp/twitter-roberta-base-irony"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        model = model.to(device)
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            model = model.half()  # Use fp16 on GPU
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=0 if use_gpu else -1,
        )
    except Exception as e:
        st.error(f"Failed to load sarcasm model: {str(e)}")
        return None
def perform_sarcasm_detection(text):
    """Estimate whether ``text`` is sarcastic.

    Args:
        text: The text to analyse. Empty text, or text shorter than three
            characters after stripping, is reported as not sarcastic without
            running the model.

    Returns:
        A tuple ``(is_sarcastic, sarcasm_score)`` where ``sarcasm_score`` is
        the score assigned to the sarcastic class (the complement of the
        predicted score when the prediction is the non-sarcastic class).
        Returns ``(False, 0.0)`` when the text is too short, the classifier
        is unavailable, or an error occurs; errors are reported via
        ``st.error``.
    """
    try:
        if not text or len(text.strip()) < 3:
            return False, 0.0
        sarcasm_classifier = get_sarcasm_classifier()
        if not sarcasm_classifier:
            return False, 0.0
        result = sarcasm_classifier(text)[0]
        # Accept both the generic "LABEL_1" id and a human-readable "irony"
        # label, case-insensitively.
        # NOTE(review): which of the two this model's config exposes should
        # be confirmed against its config.json; matching both is harmless.
        label = str(result['label']).strip().lower()
        is_sarcastic = label in ("label_1", "irony")
        sarcasm_score = result['score'] if is_sarcastic else 1 - result['score']
        return is_sarcastic, sarcasm_score
    except Exception as e:
        st.error(f"Sarcasm detection failed: {str(e)}")
        return False, 0.0
# Validate audio quality
def validate_audio(audio_path):
    """Check that the audio at ``audio_path`` is loud and long enough.

    The file is rejected with an ``st.warning`` when its ``dBFS`` is below
    -55 or when ``len()`` of the loaded segment is below 1000. A file that
    cannot be loaded is rejected with an ``st.error``.

    Returns:
        ``True`` if the audio passes both checks, ``False`` otherwise.
    """
    try:
        clip = AudioSegment.from_file(audio_path)
        # Volume is checked first; length is only checked for audible clips.
        if clip.dBFS < -55:
            problem = "Audio volume is too low."
        elif len(clip) < 1000:
            problem = "Audio is too short."
        else:
            return True
        st.warning(problem)
        return False
    except Exception as e:
        st.error(f"Invalid audio file: {str(e)}")
        return False
# Speech Recognition with Whisper
@st.cache_resource
def load_whisper_model():
    """Load the Whisper "base" model onto the module-level ``device``.

    The result is wrapped in ``st.cache_resource``.

    Returns:
        The loaded model, or ``None`` if loading failed; the failure is
        reported via ``st.error``.
    """
    try:
        return whisper.load_model("base").to(device)
    except Exception as e:
        st.error(f"Failed to load Whisper model: {str(e)}")
        return None
def transcribe_audio(audio_path):
    """Transcribe the audio file at ``audio_path`` to English text.

    The audio is re-encoded to a 16000 Hz mono WAV in a temporary file,
    which is passed to the Whisper model and removed afterwards. The input
    file itself is left untouched.

    Args:
        audio_path: Path to an audio file readable by ``AudioSegment``.

    Returns:
        The stripped transcription, or ``""`` if the model is unavailable or
        transcription fails; failures are reported via ``st.error``.
    """
    temp_wav_path = None
    try:
        # Load the model first so no conversion work is wasted when it is
        # unavailable (load_whisper_model reports its own error).
        model = load_whisper_model()
        if not model:
            return ""
        sound = AudioSegment.from_file(audio_path).set_frame_rate(16000).set_channels(1)
        # mkstemp yields a unique name; a timestamp-based name can collide
        # when two sessions transcribe within the same second.
        fd, temp_wav_path = tempfile.mkstemp(prefix="temp_", suffix=".wav")
        os.close(fd)
        sound.export(temp_wav_path, format="wav")
        result = model.transcribe(temp_wav_path, language="en", fp16=torch.cuda.is_available())
        return result["text"].strip()
    except Exception as e:
        st.error(f"Transcription failed: {str(e)}")
        return ""
    finally:
        if temp_wav_path and os.path.exists(temp_wav_path):
            os.remove(temp_wav_path)
# Process uploaded audio files
def process_uploaded_audio(audio_file):
    """Save an uploaded audio file to a temporary path and validate it.

    Args:
        audio_file: The uploaded file object; its ``name`` and ``getvalue()``
            are used. A falsy value yields ``None``.

    Returns:
        The path of the temporary file when the upload has a supported
        extension (wav, mp3, ogg) and passes ``validate_audio``; otherwise
        ``None``. On success the file is left on disk so the caller can read
        it, and the caller is responsible for deleting it. On any failure
        the temporary file is removed here.
    """
    if not audio_file:
        return None
    temp_file_path = None
    keep_file = False
    try:
        ext = audio_file.name.split('.')[-1].lower()
        if ext not in ('wav', 'mp3', 'ogg'):
            st.error("Unsupported audio format. Use WAV, MP3, or OGG.")
            return None
        # mkstemp yields a unique name; a timestamp-based name can collide
        # when two uploads are processed within the same second.
        fd, temp_file_path = tempfile.mkstemp(prefix="uploaded_", suffix=f".{ext}")
        with os.fdopen(fd, "wb") as f:
            f.write(audio_file.getvalue())
        if not validate_audio(temp_file_path):
            return None
        keep_file = True
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing uploaded audio: {str(e)}")
        return None
    finally:
        # Only clean up on failure: deleting unconditionally would hand the
        # caller a path to a file that no longer exists.
        if not keep_file and temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
# Show model information
def show_model_info():
    """Render the sidebar section listing the models used by the app."""
    sidebar = st.sidebar
    sidebar.header("🧠 About the Models")
    with sidebar.expander("Model Details"):
        # The literal is kept flush-left so the Markdown list is unchanged.
        st.markdown("""
- *Emotion*: DistilBERT (bhadresh-savani/distilbert-base-uncased-emotion)
- *Sarcasm*: RoBERTa (cardiffnlp/twitter-roberta-base-irony)
- *Speech*: OpenAI Whisper (base)
""")
# Custom audio recorder
def custom_audio_recorder():
    """Render an in-browser record/stop button and return the component value.

    The embedded script records from the microphone with ``MediaRecorder``,
    reads the recording as a data URL and posts it to the parent window as a
    ``streamlit:setComponentValue`` message.

    Returns:
        Whatever ``components.html`` returns.
        NOTE(review): ``components.html`` appears to be a display-only API;
        confirm that the posted value is actually delivered back to Python.
    """
    st.warning("Recording requires microphone access and a modern browser.")
    # The markup is passed through verbatim (kept flush-left, unchanged).
    return components.html("""
<script>
let recorder, stream;
async function startRecording() {
try {
stream = await navigator.mediaDevices.getUserMedia({ audio: true });
recorder = new MediaRecorder(stream);
const chunks = [];
recorder.ondataavailable = e => chunks.push(e.data);
recorder.onstop = () => {
const blob = new Blob(chunks, { type: 'audio/wav' });
const reader = new FileReader();
reader.onloadend = () => {
window.parent.postMessage({type: "streamlit:setComponentValue", value: reader.result}, "*");
};
reader.readAsDataURL(blob);
stream.getTracks().forEach(track => track.stop());
};
recorder.start();
document.getElementById('record-btn').textContent = 'Stop Recording';
} catch (e) { alert('Recording failed: ' + e.message); }
}
function stopRecording() {
recorder.stop();
document.getElementById('record-btn').textContent = 'Start Recording';
}
function toggleRecording() {
if (!recorder || recorder.state === 'inactive') startRecording();
else stopRecording();
}
</script>
<button id="record-btn" onclick="toggleRecording()">Start Recording</button>
<style>
#record-btn {
background-color: #f63366;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
}
#record-btn:hover {
background-color: #ff0000;
}
</style>
""", height=100)
# Display analysis results
def display_analysis_results(transcribed_text):
    """Analyse ``transcribed_text`` and render the results on the page.

    Runs emotion and sarcasm detection, then shows the text, a
    sentiment/sarcasm summary in a narrow column and, in a wider column,
    the dominant emotion with a bar chart of all emotion scores.
    """
    emotions_dict, top_emotion, emotion_map, sentiment = perform_emotion_detection(transcribed_text)
    is_sarcastic, sarcasm_score = perform_sarcasm_detection(transcribed_text)

    st.header("Analysis Results")
    st.text_area("Transcribed Text", transcribed_text, height=100, disabled=True)

    summary_col, chart_col = st.columns([1, 2])

    with summary_col:
        st.subheader("Sentiment")
        # Any sentiment other than POSITIVE/NEGATIVE gets the neutral icon.
        sentiment_icon = {"POSITIVE": "πŸ‘", "NEGATIVE": "πŸ‘Ž"}.get(sentiment, "😐")
        st.markdown(f"{sentiment_icon} {sentiment} (Based on {top_emotion})")

        st.subheader("Sarcasm")
        if is_sarcastic:
            sarcasm_icon, sarcasm_status = "😏", "Detected"
        else:
            sarcasm_icon, sarcasm_status = "😐", "Not Detected"
        st.markdown(f"{sarcasm_icon} {sarcasm_status} (Score: {sarcasm_score:.2f})")

    with chart_col:
        st.subheader("Emotions")
        if not emotions_dict:
            st.write("No emotions detected.")
        else:
            st.markdown(f"*Dominant:* {emotion_map.get(top_emotion, '❓')} {top_emotion.capitalize()} (Score: {emotions_dict[top_emotion]:.3f})")
            fig = px.bar(
                x=list(emotions_dict),
                y=list(emotions_dict.values()),
                labels={'x': 'Emotion', 'y': 'Score'},
                title="Emotion Distribution",
            )
            st.plotly_chart(fig, use_container_width=True)
# Process base64 audio
def process_base64_audio(base64_data):
    """Decode base64 audio into a temporary file and validate it.

    Args:
        base64_data: Either a data URL (``"<header>,<base64>"``) or a bare
            base64 string. A falsy value yields ``None``.

    Returns:
        The path of the temporary file when decoding succeeds and the audio
        passes ``validate_audio``; otherwise ``None``. On success the file
        is left on disk so the caller can read it, and the caller is
        responsible for deleting it. On any failure the temporary file is
        removed here and errors are reported via ``st.error``.
    """
    if not base64_data:
        return None
    temp_file_path = None
    keep_file = False
    try:
        # Take the part after the first comma of a data URL; a string with
        # no comma is treated as the payload itself instead of raising.
        header, sep, payload = base64_data.partition(',')
        if not sep:
            payload = header
        else:
            # Matches the original split(',')[1] for multi-comma input.
            payload = payload.split(',')[0]
        audio_bytes = base64.b64decode(payload)
        # mkstemp yields a unique name; a timestamp-based name can collide
        # when two recordings are processed within the same second.
        fd, temp_file_path = tempfile.mkstemp(prefix="rec_", suffix=".wav")
        with os.fdopen(fd, "wb") as f:
            f.write(audio_bytes)
        if not validate_audio(temp_file_path):
            return None
        keep_file = True
        return temp_file_path
    except Exception as e:
        st.error(f"Error processing recorded audio: {str(e)}")
        return None
    finally:
        # Only clean up on failure: deleting unconditionally would hand the
        # caller a path to a file that no longer exists.
        if not keep_file and temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
# Main App Logic
def _transcribe_and_display(temp_audio_path, failure_message):
    """Transcribe a temp audio file, show the analysis, then delete the file."""
    try:
        transcribed_text = transcribe_audio(temp_audio_path)
        if transcribed_text:
            display_analysis_results(transcribed_text)
        else:
            st.error(failure_message)
    finally:
        # The temp file is owned by this flow; the existence check keeps the
        # cleanup harmless if the file has already been removed.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)


def main():
    """Render the app: upload tab, record/manual-text tab and model sidebar.

    Each analysis is triggered by its own button. Audio inputs are saved to
    a temporary file, transcribed and analysed, and the temporary file is
    removed afterwards; manually entered text is analysed directly.
    """
    if 'debug_info' not in st.session_state:
        st.session_state.debug_info = []
    tab1, tab2 = st.tabs(["πŸ“ Upload Audio", "πŸŽ™ Record Audio"])
    with tab1:
        st.header("Upload an Audio File")
        audio_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
        if audio_file:
            st.audio(audio_file.getvalue())
            if st.button("Analyze Upload", key="analyze_upload"):
                with st.spinner("Analyzing audio..."):
                    temp_audio_path = process_uploaded_audio(audio_file)
                    if temp_audio_path:
                        _transcribe_and_display(
                            temp_audio_path,
                            "Could not transcribe audio. Try clearer audio.",
                        )
    with tab2:
        st.header("Record Your Voice")
        st.subheader("Browser-Based Recorder")
        audio_data = custom_audio_recorder()
        if audio_data and st.button("Analyze Recording", key="analyze_rec"):
            with st.spinner("Processing recording..."):
                temp_audio_path = process_base64_audio(audio_data)
                if temp_audio_path:
                    _transcribe_and_display(
                        temp_audio_path,
                        "Could not transcribe audio. Speak clearly.",
                    )
        st.subheader("Manual Text Input")
        manual_text = st.text_area("Enter text to analyze:", placeholder="Type your text...")
        if st.button("Analyze Text", key="analyze_manual") and manual_text:
            display_analysis_results(manual_text)
    show_model_info()
# Build the UI only when this file is executed as the main script.
if __name__ == "__main__":
    main()