Update app.py
app.py (changed)
@@ -23,16 +23,12 @@ except ImportError:
     USE_TORCHAUDIO = False
     st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio")
 
-# Suppress warnings
+# Suppress warnings and set logging
 logging.getLogger("torch").setLevel(logging.ERROR)
 logging.getLogger("transformers").setLevel(logging.ERROR)
 warnings.filterwarnings("ignore")
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-# Device setup
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-st.write(f"Using device: {device}")
-
 # Streamlit config
 st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
 st.title("π Voice Sentiment Analysis")
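
This hunk removes the CUDA probe entirely, so `device` no longer exists as a global and every later reference to it has to go as well; the following hunks switch the pipelines to `device=-1` for the same reason. If GPU support were ever reintroduced, a guarded index along these lines (a sketch, not part of this commit; `device_index` is a hypothetical name) would preserve the CPU default:

    import torch

    # transformers' pipeline() takes an integer device index, so map
    # "CUDA available" to GPU 0 and everything else to -1 (CPU),
    # rather than keeping a torch.device object around.
    device_index = 0 if torch.cuda.is_available() else -1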
@@ -42,19 +38,20 @@ st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from v
 @st.cache_resource
 def load_models():
     try:
+        # Load Whisper model with CPU optimization
         whisper_model = whisper.load_model("base")
 
+        # Load emotion detection model
         emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
         emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
-        emotion_model = emotion_model.to(device).half()
         emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
-                                      top_k=None, device
+                                      top_k=None, device=-1)  # CPU only
 
+        # Load sarcasm detection model
         sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
         sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
-        sarcasm_model = sarcasm_model.to(device).half()
         sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
-                                      device
+                                      device=-1)  # CPU only
 
         return whisper_model, emotion_classifier, sarcasm_classifier
     except Exception as e:
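
`device=-1` pins a transformers pipeline to CPU, and dropping the `.to(device).half()` calls matters on that path: fp16 weights are primarily a GPU optimization, and many CPU kernels either lack half-precision support or run slower with it. A minimal standalone sketch of the pattern this hunk adopts (same model name as the diff; loading by name here for brevity, where the diff builds the model and tokenizer objects first):

    from transformers import pipeline

    # CPU-only text classification; top_k=None returns scores for every label.
    emotion_classifier = pipeline(
        "text-classification",
        model="bhadresh-savani/distilbert-base-uncased-emotion",
        top_k=None,
        device=-1,  # -1 = CPU; a non-negative index would select a GPU
    )
    for item in emotion_classifier("this is wonderful")[0]:
        print(item["label"], round(item["score"], 3))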
@@ -72,7 +69,7 @@ async def perform_emotion_detection(text):
         results = emotion_classifier(text)[0]
         emotions_dict = {r['label']: r['score'] for r in results}
         filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
-        top_emotion = max(filtered_emotions, key=filtered_emotions.get)
+        top_emotion = max(filtered_emotions, key=filtered_emotions.get, default="neutral")
 
         positive_emotions = ["joy"]
         negative_emotions = ["anger", "disgust", "fear", "sadness"]
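
The `default=` argument is what prevents a crash here: if every emotion scores at or below 0.01, `filtered_emotions` is empty, and `max()` over an empty iterable raises `ValueError` unless a default is supplied. For example:

    scores = {}  # every label filtered out by the 0.01 threshold
    try:
        max(scores, key=scores.get)  # raises ValueError on an empty dict
    except ValueError:
        pass
    print(max(scores, key=scores.get, default="neutral"))  # -> neutral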
@@ -131,16 +128,16 @@ def transcribe_audio(audio_path):
             waveform, sample_rate = torchaudio.load(audio_path)
             if sample_rate != 16000:
                 resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-                waveform
+                waveform = resampler(waveform)
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                 torchaudio.save(temp_file.name, waveform, 16000)
-                result = whisper_model.transcribe(temp_file.name, language="en")
+                result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6)
         else:
             sound = AudioSegment.from_file(audio_path)
             sound = sound.set_frame_rate(16000).set_channels(1)
             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                 sound.export(temp_file.name, format="wav")
-                result = whisper_model.transcribe(temp_file.name, language="en")
+                result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6)
         os.remove(temp_file.name)
         return result["text"].strip()
     except Exception as e:
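
The first change fixes a real bug: the old line was a bare `waveform` expression, so the resampler's output was discarded and the original-rate audio was written out. `Resample` is a callable module whose return value must be reassigned, roughly (hypothetical file paths for illustration):

    import torchaudio

    waveform, sample_rate = torchaudio.load("input.wav")
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)  # reassign; calling it alone does nothing
    torchaudio.save("resampled.wav", waveform, 16000)

The `no_speech_threshold=0.6` added to both `transcribe()` calls matches openai-whisper's default of 0.6, so it mostly makes the silence-skipping behavior explicit.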
@@ -168,6 +165,9 @@ def process_uploaded_audio(audio_file):
 # Process base64 audio
 def process_base64_audio(base64_data):
     try:
+        if not base64_data.startswith("data:audio"):
+            st.error("Invalid audio data.")
+            return None
         base64_binary = base64_data.split(',')[1]
         binary_data = base64.b64decode(base64_binary)
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
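
The guard matters because `split(',')[1]` on a string with no comma raises `IndexError`, which the function's generic exception handler would turn into an unhelpful message; checking for a `data:audio...;base64,`-style data URL up front rejects bad input with a clear error instead. In isolation the decode path looks like this (a sketch; `decode_audio_data_url` and the payload are hypothetical):

    import base64

    def decode_audio_data_url(data_url):
        # Browser-recorded audio arrives as "data:audio/wav;base64,<payload>".
        if not data_url.startswith("data:audio"):
            return None  # not an audio data URL
        return base64.b64decode(data_url.split(",")[1])

    payload = "data:audio/wav;base64," + base64.b64encode(b"RIFF....WAVE").decode()
    print(len(decode_audio_data_url(payload)))  # 12 bytes back out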
@@ -339,5 +339,4 @@ def main():
         display_analysis_results(manual_text)
 
 if __name__ == "__main__":
-    main()
-    torch.cuda.empty_cache()
+    main()
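
Dropping `torch.cuda.empty_cache()` is consistent with the CPU-only direction of the commit: the call only releases memory held by the CUDA caching allocator and does nothing when CUDA is never initialized, and placed after `main()` in a Streamlit script it would in any case only run after all the work of a rerun has finished.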