Diggz10 committed
Commit ab12152 · verified · 1 Parent(s): ccaa441

Update app.py

Files changed (1):
  1. app.py +103 -22
app.py CHANGED
@@ -2,38 +2,119 @@ import gradio as gr
  from transformers import pipeline
  import soundfile as sf
  import os

  try:
-     classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
  except Exception as e:
-     def error_fn(audio_file):
-         return {"error": f"Failed to load the model. Please check the logs. Error: {str(e)}"}
-     classifier = None

  def predict_emotion(audio_file):
-     if classifier is None: return {"error": "The AI model could not be loaded."}
-     if audio_file is None: return {"error": "No audio input provided."}
-     if isinstance(audio_file, str): audio_path = audio_file
-     elif isinstance(audio_file, tuple):
-         sample_rate, audio_array = audio_file
-         temp_audio_path = "temp_audio_from_mic.wav"
-         sf.write(temp_audio_path, audio_array, sample_rate)
-         audio_path = temp_audio_path
-     else: return {"error": f"Invalid audio input format: {type(audio_file)}"}
      try:
-         results = classifier(audio_path, top_k=5)
-         return {item['label']: round(item['score'], 3) for item in results}
-     except Exception as e: return {"error": f"An error occurred during prediction: {str(e)}"}
      finally:
-         if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path): os.remove(temp_audio_path)

  iface = gr.Interface(
      fn=predict_emotion,
-     inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record with Microphone"),
-     outputs=gr.Label(num_top_classes=5, label="Emotion Probabilities"),
-     title="AI Audio Emotion Detector",
-     description="Upload an audio file or record your voice to detect emotions.",
  )

  if __name__ == "__main__":
-     iface.queue().launch()
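For reference, the removed version asked the pipeline for its five best labels at call time via `top_k`, rather than configuring score output at construction time. A minimal standalone sketch of that call pattern, using the model name from the removed line above (`clip.wav` is a placeholder path):

```python
from transformers import pipeline

# Model checkpoint taken from the removed line above.
classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")

# top_k caps how many labels come back; each entry is a
# {"label": str, "score": float} dict, sorted by descending score.
results = classifier("clip.wav", top_k=5)  # "clip.wav" is a placeholder
print({item["label"]: round(item["score"], 3) for item in results})
```

The rewritten version below keeps the same audio-classification task but adds a fallback checkpoint, structured logging, and explicit temp-file cleanup.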
  from transformers import pipeline
  import soundfile as sf
  import os
+ import logging

+ # Set up logging to help debug issues
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize the classifier with error handling
  try:
+     # Using a more reliable emotion classification model
+     classifier = pipeline(
+         "audio-classification",
+         model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
+         return_all_scores=True
+     )
+     logger.info("Model loaded successfully")
  except Exception as e:
+     logger.error(f"Failed to load primary model: {e}")
+     try:
+         # Fallback to a different model
+         classifier = pipeline(
+             "audio-classification",
+             model="superb/wav2vec2-base-superb-er",
+             return_all_scores=True
+         )
+         logger.info("Fallback model loaded successfully")
+     except Exception as e2:
+         logger.error(f"Failed to load fallback model: {e2}")
+         classifier = None

  def predict_emotion(audio_file):
+     """
+     Predict emotion from audio file
+     """
+     if classifier is None:
+         return {"error": "The AI model could not be loaded. Please check the logs."}
+
+     if audio_file is None:
+         return {"error": "No audio input provided."}
+
+     temp_audio_path = None
+
      try:
+         # Handle different input types
+         if isinstance(audio_file, str):
+             audio_path = audio_file
+         elif isinstance(audio_file, tuple):
+             sample_rate, audio_array = audio_file
+             temp_audio_path = "temp_audio_from_mic.wav"
+             sf.write(temp_audio_path, audio_array, sample_rate)
+             audio_path = temp_audio_path
+         else:
+             return {"error": f"Invalid audio input format: {type(audio_file)}"}
+
+         # Check if file exists
+         if not os.path.exists(audio_path):
+             return {"error": "Audio file not found"}
+
+         # Perform emotion classification
+         logger.info(f"Processing audio file: {audio_path}")
+         results = classifier(audio_path)
+
+         # Process results
+         if isinstance(results, list) and len(results) > 0:
+             # Sort by score and return top 5
+             sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)[:5]
+             emotion_scores = {item['label']: round(item['score'], 3) for item in sorted_results}
+         else:
+             return {"error": "No valid results from the model"}
+
+         logger.info(f"Prediction successful: {emotion_scores}")
+         return emotion_scores
+
+     except Exception as e:
+         logger.error(f"Error during prediction: {str(e)}")
+         return {"error": f"An error occurred during prediction: {str(e)}"}
+
      finally:
+         # Clean up temporary file
+         if temp_audio_path and os.path.exists(temp_audio_path):
+             try:
+                 os.remove(temp_audio_path)
+                 logger.info("Temporary audio file cleaned up")
+             except Exception as e:
+                 logger.warning(f"Failed to clean up temp file: {e}")

+ # Create Gradio interface
  iface = gr.Interface(
      fn=predict_emotion,
+     inputs=gr.Audio(
+         sources=["microphone", "upload"],
+         type="filepath",
+         label="Upload Audio or Record with Microphone"
+     ),
+     outputs=gr.Label(
+         num_top_classes=5,
+         label="Emotion Probabilities"
+     ),
+     title="🎵 AI Audio Emotion Detector",
+     description="Upload an audio file or record your voice to detect emotions. Supported formats: WAV, MP3, M4A, FLAC.",
+     article="This tool uses advanced AI models to analyze emotional content in speech and audio.",
+     examples=None,  # You can add example audio files here if you have them
+     allow_flagging="never"
  )

  if __name__ == "__main__":
+     try:
+         # Launch with queue for better handling of concurrent requests
+         iface.queue(max_size=10).launch(
+             server_name="0.0.0.0",  # Allow external access
+             server_port=7860,  # Default Gradio port
+             share=True,  # Create a public link
+             debug=True  # Enable debug mode
+         )
+     except Exception as e:
+         logger.error(f"Failed to launch Gradio app: {e}")
+         print(f"Error launching app: {e}")
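A quick way to smoke-test the new handler outside the browser is to call `predict_emotion` directly with a synthetic `(sample_rate, array)` tuple, which exercises the microphone branch and the temp-file cleanup. A sketch, assuming the code above is saved as `app.py` (importing it loads the model, so the first call is slow):

```python
import numpy as np

from app import predict_emotion  # assumes the diff above is saved as app.py

# One second of a 440 Hz sine at 16 kHz, standing in for a mic recording.
sample_rate = 16000
t = np.linspace(0, 1, sample_rate, endpoint=False)
audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# Returns a {label: rounded_score} dict, or {"error": ...} on failure.
print(predict_emotion((sample_rate, audio)))
```

One thing worth verifying on a real run: `return_all_scores` is historically a text-classification pipeline option, and the audio-classification pipeline may ignore or reject it; the call-time `top_k` argument the removed version used (or `top_k=None` to request every label) is the safer way to control how many scores come back.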