Diggz10 commited on
Commit
5cde27f
·
verified ·
1 Parent(s): 4c23f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -9
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  from transformers import pipeline
3
  import soundfile as sf
4
  import os
 
 
5
 
6
  # --- Model Loading ---
7
  try:
@@ -13,33 +15,72 @@ except Exception as e:
13
 
14
  # --- Prediction Function ---
15
  def predict_emotion(audio_file):
16
- if classifier is None: return {"error": "The AI model could not be loaded."}
17
- if audio_file is None: return {"error": "No audio input provided."}
18
- if isinstance(audio_file, str): audio_path = audio_file
 
 
 
 
 
 
19
  elif isinstance(audio_file, tuple):
20
  sample_rate, audio_array = audio_file
21
  temp_audio_path = "temp_audio_from_mic.wav"
22
  sf.write(temp_audio_path, audio_array, sample_rate)
23
  audio_path = temp_audio_path
24
- else: return {"error": f"Invalid audio input format: {type(audio_file)}"}
 
 
25
  try:
26
  results = classifier(audio_path, top_k=5)
27
  return {item['label']: round(item['score'], 3) for item in results}
28
- except Exception as e: return {"error": f"An error occurred during prediction: {str(e)}"}
 
29
  finally:
30
- if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path): os.remove(temp_audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # --- Gradio Interface ---
 
33
  iface = gr.Interface(
34
  fn=predict_emotion,
35
  inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record with Microphone"),
36
  outputs=gr.Label(num_top_classes=5, label="Emotion Probabilities"),
37
  title="AI Audio Emotion Detector",
38
  description="Upload an audio file or record your voice to detect emotions.",
39
- # THIS LINE IS CRITICAL - WE ARE CREATING AN EXPLICIT API ENDPOINT
40
- api_name="predict"
41
  )
42
 
43
- # Launch the Gradio app with explicit server settings
44
  if __name__ == "__main__":
45
  iface.queue().launch(server_name="0.0.0.0", share=True)
 
2
  from transformers import pipeline
3
  import soundfile as sf
4
  import os
5
+ import base64
6
+ import tempfile
7
 
8
  # --- Model Loading ---
9
  try:
 
15
 
16
  # --- Prediction Function ---
17
  def predict_emotion(audio_file):
18
+ if classifier is None:
19
+ return {"error": "The AI model could not be loaded."}
20
+
21
+ if audio_file is None:
22
+ return {"error": "No audio input provided."}
23
+
24
+ # Handle different input types
25
+ if isinstance(audio_file, str):
26
+ audio_path = audio_file
27
  elif isinstance(audio_file, tuple):
28
  sample_rate, audio_array = audio_file
29
  temp_audio_path = "temp_audio_from_mic.wav"
30
  sf.write(temp_audio_path, audio_array, sample_rate)
31
  audio_path = temp_audio_path
32
+ else:
33
+ return {"error": f"Invalid audio input format: {type(audio_file)}"}
34
+
35
  try:
36
  results = classifier(audio_path, top_k=5)
37
  return {item['label']: round(item['score'], 3) for item in results}
38
+ except Exception as e:
39
+ return {"error": f"An error occurred during prediction: {str(e)}"}
40
  finally:
41
+ if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
42
+ os.remove(temp_audio_path)
43
+
44
+ # --- API Function for Base64 Input ---
45
+ def predict_emotion_api(data):
46
+ """
47
+ API function that accepts base64 encoded audio data
48
+ Expected input format: {"data": "base64_encoded_audio_string"}
49
+ """
50
+ if classifier is None:
51
+ return {"error": "The AI model could not be loaded."}
52
+
53
+ try:
54
+ # Decode base64 audio data
55
+ audio_data = base64.b64decode(data)
56
+
57
+ # Create temporary file
58
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
59
+ temp_file.write(audio_data)
60
+ temp_audio_path = temp_file.name
61
+
62
+ # Predict emotion
63
+ results = classifier(temp_audio_path, top_k=5)
64
+
65
+ # Clean up temp file
66
+ os.unlink(temp_audio_path)
67
+
68
+ return {item['label']: round(item['score'], 3) for item in results}
69
+
70
+ except Exception as e:
71
+ return {"error": f"An error occurred during prediction: {str(e)}"}
72
 
73
  # --- Gradio Interface ---
74
+ # Main interface for web UI
75
  iface = gr.Interface(
76
  fn=predict_emotion,
77
  inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record with Microphone"),
78
  outputs=gr.Label(num_top_classes=5, label="Emotion Probabilities"),
79
  title="AI Audio Emotion Detector",
80
  description="Upload an audio file or record your voice to detect emotions.",
81
+ api_name="predict" # This creates /api/predict/ endpoint
 
82
  )
83
 
84
+ # Launch the Gradio app
85
  if __name__ == "__main__":
86
  iface.queue().launch(server_name="0.0.0.0", share=True)