Diggz10 committed on
Commit 7a18e70 · verified · 1 Parent(s): 5cde27f

Update app.py

Files changed (1):
  1. app.py +54 -52
app.py CHANGED
@@ -4,83 +4,85 @@ import soundfile as sf

Before (app.py at parent commit 5cde27f, hunk lines 4-86):

import os
import base64
import tempfile

# --- Model Loading ---
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as e:
    def error_fn(audio_file):
        return {"error": f"Failed to load the model. Please check the logs. Error: {str(e)}"}
    classifier = None

# --- Prediction Function ---
def predict_emotion(audio_file):
    if classifier is None:
        return {"error": "The AI model could not be loaded."}

    if audio_file is None:
        return {"error": "No audio input provided."}

    # Handle different input types
    if isinstance(audio_file, str):
        audio_path = audio_file
    elif isinstance(audio_file, tuple):
        sample_rate, audio_array = audio_file
        temp_audio_path = "temp_audio_from_mic.wav"
        sf.write(temp_audio_path, audio_array, sample_rate)
        audio_path = temp_audio_path
    else:
        return {"error": f"Invalid audio input format: {type(audio_file)}"}

    try:
        results = classifier(audio_path, top_k=5)
        return {item['label']: round(item['score'], 3) for item in results}
    except Exception as e:
        return {"error": f"An error occurred during prediction: {str(e)}"}
    finally:
        if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

# --- API Function for Base64 Input ---
def predict_emotion_api(data):
    """
    API function that accepts base64 encoded audio data
    Expected input format: {"data": "base64_encoded_audio_string"}
    """
    if classifier is None:
        return {"error": "The AI model could not be loaded."}

    try:
        # Decode base64 audio data
        audio_data = base64.b64decode(data)

        # Create temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name

        # Predict emotion
        results = classifier(temp_audio_path, top_k=5)

        # Clean up temp file
        os.unlink(temp_audio_path)

        return {item['label']: round(item['score'], 3) for item in results}

    except Exception as e:
        return {"error": f"An error occurred during prediction: {str(e)}"}

# --- Gradio Interface ---
# Main interface for web UI
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record with Microphone"),
    outputs=gr.Label(num_top_classes=5, label="Emotion Probabilities"),
    title="AI Audio Emotion Detector",
    description="Upload an audio file or record your voice to detect emotions.",
    api_name="predict"  # This creates /api/predict/ endpoint
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.queue().launch(server_name="0.0.0.0", share=True)
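In both the old and the new version, predict_emotion accepts either a path to an audio file or a (sample_rate, numpy_array) tuple like the one gr.Audio emits in numpy mode. A minimal local smoke test, assuming app.py is importable from the working directory, numpy and the model are already available, and ffmpeg is installed for the transformers pipeline's audio decoding; the 440 Hz tone is purely illustrative:

import numpy as np
from app import predict_emotion  # app.py from this repo

# One second of a 440 Hz tone at 16 kHz, standing in for real speech.
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# The tuple form exercises the microphone-style branch; passing a .wav
# path instead would exercise the filepath branch.
print(predict_emotion((sr, tone)))  # prints a {label: score} dict or {"error": ...}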
 
 
After (app.py at commit 7a18e70, hunk lines 4-88):

import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn

# --- Load Model ---
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as e:
    classifier = None
    model_load_error = str(e)
else:
    model_load_error = None

# --- Gradio Prediction Function ---
def predict_emotion(audio_file):
    if classifier is None:
        return {"error": f"Model load failed: {model_load_error}"}
    if audio_file is None:
        return {"error": "No audio input provided."}

    try:
        if isinstance(audio_file, str):
            audio_path = audio_file
        elif isinstance(audio_file, tuple):
            sample_rate, audio_array = audio_file
            temp_audio_path = "temp_audio.wav"
            sf.write(temp_audio_path, audio_array, sample_rate)
            audio_path = temp_audio_path
        else:
            return {"error": f"Unsupported input type: {type(audio_file)}"}

        results = classifier(audio_path, top_k=5)
        return {item['label']: round(item['score'], 3) for item in results}
    except Exception as e:
        return {"error": f"Prediction error: {str(e)}"}
    finally:
        if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

# --- FastAPI App for Base64 API ---
app = FastAPI()

@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    if classifier is None:
        return JSONResponse(content={"error": f"Model load failed: {model_load_error}"}, status_code=500)

    try:
        body = await request.json()
        base64_audio = body.get("data")
        if not base64_audio:
            return JSONResponse(content={"error": "Missing 'data' field with base64 audio."}, status_code=400)

        audio_data = base64.b64decode(base64_audio)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name

        results = classifier(temp_audio_path, top_k=5)
        os.unlink(temp_audio_path)

        return {item['label']: round(item['score'], 3) for item in results}

    except Exception as e:
        return JSONResponse(content={"error": f"API prediction failed: {str(e)}"}, status_code=500)

# --- Gradio UI ---
gradio_interface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
    outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
    title="Audio Emotion Detector",
    description="Upload or record your voice to detect emotions.",
    allow_flagging="never"
)

# --- Mount Gradio inside FastAPI ---
app = gr.mount_gradio_app(app, gradio_interface, path="/")

# --- Launch for local/dev use only ---
if __name__ == "__main__":
    gradio_interface.queue()
    uvicorn.run(app, host="0.0.0.0", port=7860)
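With this change the Gradio UI (mounted at /) and the Base64 JSON API (/api/predict/) share one Uvicorn server on port 7860, replacing the old api_name-based route and the launch(share=True) call. A sketch of a client call against the new endpoint, assuming the app is running locally and the requests package is installed; sample.wav is a placeholder path:

import base64
import requests

# Base64-encode a local WAV file to match the {"data": ...} body the
# FastAPI route expects.
with open("sample.wav", "rb") as f:
    payload = {"data": base64.b64encode(f.read()).decode("utf-8")}

resp = requests.post("http://localhost:7860/api/predict/", json=payload)
print(resp.status_code, resp.json())  # {label: score} on success, {"error": ...} otherwise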