MaroofTechSorcerer commited on
Commit
1949646
·
verified ·
1 Parent(s): 7de734d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -33
app.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import streamlit as st
3
  import tempfile
4
  import torch
5
- import torchaudio
6
  import transformers
7
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
8
  import plotly.express as px
@@ -15,6 +14,15 @@ import asyncio
15
  from concurrent.futures import ThreadPoolExecutor
16
  import streamlit.components.v1 as components
17
 
 
 
 
 
 
 
 
 
 
18
  # Suppress warnings
19
  logging.getLogger("torch").setLevel(logging.ERROR)
20
  logging.getLogger("transformers").setLevel(logging.ERROR)
@@ -33,21 +41,25 @@ st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from v
33
  # Global model cache
34
  @st.cache_resource
35
  def load_models():
36
- whisper_model = whisper.load_model("base")
37
-
38
- emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
39
- emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
40
- emotion_model = emotion_model.to(device).half()
41
- emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
42
- top_k=None, device=0 if torch.cuda.is_available() else -1)
 
43
 
44
- sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
45
- sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
46
- sarcasm_model = sarcasm_model.to(device).half()
47
- sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
48
- device=0 if torch.cuda.is_available() else -1)
49
-
50
- return whisper_model, emotion_classifier, sarcasm_classifier
 
 
 
51
 
52
  whisper_model, emotion_classifier, sarcasm_classifier = load_models()
53
 
@@ -90,29 +102,45 @@ async def perform_sarcasm_detection(text):
90
  # Audio validation
91
  def validate_audio(audio_path):
92
  try:
93
- waveform, sample_rate = torchaudio.load(audio_path)
94
- if waveform.abs().max() < 0.01:
95
- st.warning("Audio volume too low.")
96
- return False
97
- if waveform.shape[1] / sample_rate < 1:
98
- st.warning("Audio too short.")
99
- return False
 
 
 
 
 
 
 
 
 
100
  return True
101
- except:
102
- st.error("Invalid audio file.")
103
  return False
104
 
105
  # Audio transcription
106
  @st.cache_data
107
  def transcribe_audio(audio_path):
108
  try:
109
- waveform, sample_rate = torchaudio.load(audio_path)
110
- if sample_rate != 16000:
111
- resampler = torchaudio.transforms.Resample(sample_rate, 16000)
112
- waveform = resampler(waveform)
113
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
114
- torchaudio.save(temp_file.name, waveform, 16000)
115
- result = whisper_model.transcribe(temp_file.name, language="en")
 
 
 
 
 
 
 
116
  os.remove(temp_file.name)
117
  return result["text"].strip()
118
  except Exception as e:
@@ -283,7 +311,8 @@ def main():
283
  display_analysis_results(text)
284
  else:
285
  st.error("Transcription failed.")
286
- os.remove(temp_path)
 
287
  progress.empty()
288
 
289
  with tab2:
@@ -300,7 +329,8 @@ def main():
300
  display_analysis_results(text)
301
  else:
302
  st.error("Transcription failed.")
303
- os.remove(temp_path)
 
304
  progress.empty()
305
 
306
  with tab3:
 
2
  import streamlit as st
3
  import tempfile
4
  import torch
 
5
  import transformers
6
  from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
7
  import plotly.express as px
 
14
  from concurrent.futures import ThreadPoolExecutor
15
  import streamlit.components.v1 as components
16
 
17
+ # Try importing torchaudio, fallback to pydub
18
+ try:
19
+ import torchaudio
20
+ USE_TORCHAUDIO = True
21
+ except ImportError:
22
+ from pydub import AudioSegment
23
+ USE_TORCHAUDIO = False
24
+ st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio")
25
+
26
  # Suppress warnings
27
  logging.getLogger("torch").setLevel(logging.ERROR)
28
  logging.getLogger("transformers").setLevel(logging.ERROR)
 
41
  # Global model cache
42
  @st.cache_resource
43
  def load_models():
44
+ try:
45
+ whisper_model = whisper.load_model("base")
46
+
47
+ emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
48
+ emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
49
+ emotion_model = emotion_model.to(device).half()
50
+ emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
51
+ top_k=None, device=0 if torch.cuda.is_available() else -1)
52
 
53
+ sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
54
+ sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
55
+ sarcasm_model = sarcasm_model.to(device).half()
56
+ sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
57
+ device=0 if torch.cuda.is_available() else -1)
58
+
59
+ return whisper_model, emotion_classifier, sarcasm_classifier
60
+ except Exception as e:
61
+ st.error(f"Failed to load models: {str(e)}")
62
+ raise
63
 
64
  whisper_model, emotion_classifier, sarcasm_classifier = load_models()
65
 
 
102
  # Audio validation
103
  def validate_audio(audio_path):
104
  try:
105
+ if USE_TORCHAUDIO:
106
+ waveform, sample_rate = torchaudio.load(audio_path)
107
+ if waveform.abs().max() < 0.01:
108
+ st.warning("Audio volume too low.")
109
+ return False
110
+ if waveform.shape[1] / sample_rate < 1:
111
+ st.warning("Audio too short.")
112
+ return False
113
+ else:
114
+ sound = AudioSegment.from_file(audio_path)
115
+ if sound.dBFS < -55:
116
+ st.warning("Audio volume too low.")
117
+ return False
118
+ if len(sound) < 1000:
119
+ st.warning("Audio too short.")
120
+ return False
121
  return True
122
+ except Exception as e:
123
+ st.error(f"Invalid audio file: {str(e)}")
124
  return False
125
 
126
  # Audio transcription
127
  @st.cache_data
128
  def transcribe_audio(audio_path):
129
  try:
130
+ if USE_TORCHAUDIO:
131
+ waveform, sample_rate = torchaudio.load(audio_path)
132
+ if sample_rate != 16000:
133
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
134
+ waveformვ: waveform = resampler(waveform)
135
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
136
+ torchaudio.save(temp_file.name, waveform, 16000)
137
+ result = whisper_model.transcribe(temp_file.name, language="en")
138
+ else:
139
+ sound = AudioSegment.from_file(audio_path)
140
+ sound = sound.set_frame_rate(16000).set_channels(1)
141
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
142
+ sound.export(temp_file.name, format="wav")
143
+ result = whisper_model.transcribe(temp_file.name, language="en")
144
  os.remove(temp_file.name)
145
  return result["text"].strip()
146
  except Exception as e:
 
311
  display_analysis_results(text)
312
  else:
313
  st.error("Transcription failed.")
314
+ if os.path.exists(temp_path):
315
+ os.remove(temp_path)
316
  progress.empty()
317
 
318
  with tab2:
 
329
  display_analysis_results(text)
330
  else:
331
  st.error("Transcription failed.")
332
+ if os.path.exists(temp_path):
333
+ os.remove(temp_path)
334
  progress.empty()
335
 
336
  with tab3: