MaroofTechSorcerer committed on
Commit
c4f2255
·
verified ·
1 Parent(s): 51c2389

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -23,16 +23,12 @@ except ImportError:
23
  USE_TORCHAUDIO = False
24
  st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio")
25
 
26
- # Suppress warnings
27
  logging.getLogger("torch").setLevel(logging.ERROR)
28
  logging.getLogger("transformers").setLevel(logging.ERROR)
29
  warnings.filterwarnings("ignore")
30
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
31
 
32
- # Device setup
33
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- st.write(f"Using device: {device}")
35
-
36
  # Streamlit config
37
  st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
38
  st.title("🎙 Voice Sentiment Analysis")
@@ -42,19 +38,20 @@ st.markdown("Fast, accurate detection of emotions, sentiment, and sarcasm from v
42
  @st.cache_resource
43
  def load_models():
44
  try:
 
45
  whisper_model = whisper.load_model("base")
46
 
 
47
  emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
48
  emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
49
- emotion_model = emotion_model.to(device).half()
50
  emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
51
- top_k=None, device=0 if torch.cuda.is_available() else -1)
52
 
 
53
  sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
54
  sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
55
- sarcasm_model = sarcasm_model.to(device).half()
56
  sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
57
- device=0 if torch.cuda.is_available() else -1)
58
 
59
  return whisper_model, emotion_classifier, sarcasm_classifier
60
  except Exception as e:
@@ -72,7 +69,7 @@ async def perform_emotion_detection(text):
72
  results = emotion_classifier(text)[0]
73
  emotions_dict = {r['label']: r['score'] for r in results}
74
  filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
75
- top_emotion = max(filtered_emotions, key=filtered_emotions.get)
76
 
77
  positive_emotions = ["joy"]
78
  negative_emotions = ["anger", "disgust", "fear", "sadness"]
@@ -131,16 +128,16 @@ def transcribe_audio(audio_path):
131
  waveform, sample_rate = torchaudio.load(audio_path)
132
  if sample_rate != 16000:
133
  resampler = torchaudio.transforms.Resample(sample_rate, 16000)
134
- waveformვ: waveform = resampler(waveform)
135
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
136
  torchaudio.save(temp_file.name, waveform, 16000)
137
- result = whisper_model.transcribe(temp_file.name, language="en")
138
  else:
139
  sound = AudioSegment.from_file(audio_path)
140
  sound = sound.set_frame_rate(16000).set_channels(1)
141
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
142
  sound.export(temp_file.name, format="wav")
143
- result = whisper_model.transcribe(temp_file.name, language="en")
144
  os.remove(temp_file.name)
145
  return result["text"].strip()
146
  except Exception as e:
@@ -168,6 +165,9 @@ def process_uploaded_audio(audio_file):
168
  # Process base64 audio
169
  def process_base64_audio(base64_data):
170
  try:
 
 
 
171
  base64_binary = base64_data.split(',')[1]
172
  binary_data = base64.b64decode(base64_binary)
173
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
@@ -339,5 +339,4 @@ def main():
339
  display_analysis_results(manual_text)
340
 
341
  if __name__ == "__main__":
342
- main()
343
- torch.cuda.empty_cache()
 
23
  USE_TORCHAUDIO = False
24
  st.warning("torchaudio not found. Using pydub (slower). Install torchaudio: pip install torchaudio")
25
 
26
+ # Suppress warnings and set logging
27
  logging.getLogger("torch").setLevel(logging.ERROR)
28
  logging.getLogger("transformers").setLevel(logging.ERROR)
29
  warnings.filterwarnings("ignore")
30
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
31
 
 
 
 
 
32
  # Streamlit config
33
  st.set_page_config(layout="wide", page_title="Voice Sentiment Analysis")
34
  st.title("🎙 Voice Sentiment Analysis")
 
38
  @st.cache_resource
39
  def load_models():
40
  try:
41
+ # Load Whisper model with CPU optimization
42
  whisper_model = whisper.load_model("base")
43
 
44
+ # Load emotion detection model
45
  emotion_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
46
  emotion_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/distilbert-base-uncased-emotion")
 
47
  emotion_classifier = pipeline("text-classification", model=emotion_model, tokenizer=emotion_tokenizer,
48
+ top_k=None, device=-1) # CPU only
49
 
50
+ # Load sarcasm detection model
51
  sarcasm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
52
  sarcasm_model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-irony")
 
53
  sarcasm_classifier = pipeline("text-classification", model=sarcasm_model, tokenizer=sarcasm_tokenizer,
54
+ device=-1) # CPU only
55
 
56
  return whisper_model, emotion_classifier, sarcasm_classifier
57
  except Exception as e:
 
69
  results = emotion_classifier(text)[0]
70
  emotions_dict = {r['label']: r['score'] for r in results}
71
  filtered_emotions = {k: v for k, v in emotions_dict.items() if v > 0.01}
72
+ top_emotion = max(filtered_emotions, key=filtered_emotions.get, default="neutral")
73
 
74
  positive_emotions = ["joy"]
75
  negative_emotions = ["anger", "disgust", "fear", "sadness"]
 
128
  waveform, sample_rate = torchaudio.load(audio_path)
129
  if sample_rate != 16000:
130
  resampler = torchaudio.transforms.Resample(sample_rate, 16000)
131
+ waveform = resampler(waveform)
132
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
133
  torchaudio.save(temp_file.name, waveform, 16000)
134
+ result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6)
135
  else:
136
  sound = AudioSegment.from_file(audio_path)
137
  sound = sound.set_frame_rate(16000).set_channels(1)
138
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
139
  sound.export(temp_file.name, format="wav")
140
+ result = whisper_model.transcribe(temp_file.name, language="en", no_speech_threshold=0.6)
141
  os.remove(temp_file.name)
142
  return result["text"].strip()
143
  except Exception as e:
 
165
  # Process base64 audio
166
  def process_base64_audio(base64_data):
167
  try:
168
+ if not base64_data.startswith("data:audio"):
169
+ st.error("Invalid audio data.")
170
+ return None
171
  base64_binary = base64_data.split(',')[1]
172
  binary_data = base64.b64decode(base64_binary)
173
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
 
339
  display_analysis_results(manual_text)
340
 
341
  if __name__ == "__main__":
342
+ main()