Diggz10 committed
Commit 1bbfae6 · verified · 1 Parent(s): e92f5d4

Update app.py

Files changed (1)
  1. app.py +33 -45
app.py CHANGED
@@ -17,72 +17,60 @@ except Exception as e:
 else:
     model_load_error = None

-# --- Gradio Prediction Function ---
-def predict_emotion(audio_file):
-    if classifier is None:
-        return {"error": f"Model load failed: {model_load_error}"}
-    if audio_file is None:
-        return {"error": "No audio input provided."}
-
-    try:
-        if isinstance(audio_file, str):
-            audio_path = audio_file
-        elif isinstance(audio_file, tuple):
-            sample_rate, audio_array = audio_file
-            temp_audio_path = "temp_audio.wav"
-            sf.write(temp_audio_path, audio_array, sample_rate)
-            audio_path = temp_audio_path
-        else:
-            return {"error": f"Unsupported input type: {type(audio_file)}"}
-
-        results = classifier(audio_path, top_k=5)
-        return {item['label']: round(item['score'], 3) for item in results}
-    except Exception as e:
-        return {"error": f"Prediction error: {str(e)}"}
-    finally:
-        if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
-            os.remove(temp_audio_path)
-
-# --- FastAPI App for Base64 API ---
+# --- FastAPI App for a dedicated, robust API ---
 app = FastAPI()

 @app.post("/api/predict/")
 async def predict_emotion_api(request: Request):
     if classifier is None:
-        return JSONResponse(content={"error": f"Model load failed: {model_load_error}"}, status_code=500)
+        return JSONResponse(content={"error": f"Model is not loaded: {model_load_error}"}, status_code=503)

     try:
         body = await request.json()
-        base64_audio = body.get("data")
-        if not base64_audio:
-            return JSONResponse(content={"error": "Missing 'data' field with base64 audio."}, status_code=400)
+        # The JS FileReader sends a string like "data:audio/wav;base64,AABBCC..."
+        base64_with_prefix = body.get("data")

-        audio_data = base64.b64decode(base64_audio)
+        if not base64_with_prefix:
+            return JSONResponse(content={"error": "Missing 'data' field in request body."}, status_code=400)
+
+        # Robustly strip the prefix to get the pure base64 data
+        try:
+            # Find the comma that separates the prefix from the data
+            header, encoded = base64_with_prefix.split(",", 1)
+            audio_data = base64.b64decode(encoded)
+        except (ValueError, TypeError):
+            return JSONResponse(content={"error": "Invalid base64 data format."}, status_code=400)
+
+        # Write to a temporary file for the pipeline
         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
             temp_file.write(audio_data)
             temp_audio_path = temp_file.name

         results = classifier(temp_audio_path, top_k=5)
-        os.unlink(temp_audio_path)
+        os.unlink(temp_audio_path)  # Clean up the temp file
+
+        # Return a successful response
+        return JSONResponse(content={"data": results})

-        return {item['label']: round(item['score'], 3) for item in results}
     except Exception as e:
-        return JSONResponse(content={"error": f"API prediction failed: {str(e)}"}, status_code=500)
+        return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)
+
+# --- Gradio UI function (optional, for the direct Space page) ---
+def gradio_predict_wrapper(audio_file):
+    # This is just for the UI on the Hugging Face page itself
+    if audio_file is None: return {"error": "Please provide an audio file."}
+    results = classifier(audio_file, top_k=5)
+    return {item['label']: round(item['score'], 3) for item in results}

-# --- Gradio UI ---
 gradio_interface = gr.Interface(
-    fn=predict_emotion,
+    fn=gradio_predict_wrapper,
     inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
     outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
     title="Audio Emotion Detector",
-    description="Upload or record your voice to detect emotions.",
+    description="This UI is for direct demonstration. The primary API is at /api/predict/",
     allow_flagging="never"
 )

-# --- Mount Gradio inside FastAPI ---
-app = gr.mount_gradio_app(app, gradio_interface, path="/")
-
-# --- Launch for local/dev use only ---
-if __name__ == "__main__":
-    gradio_interface.queue()
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+# --- Mount the Gradio UI onto the FastAPI app ---
+# The API at /api/predict/ will work even if the UI is at a different path.
+app = gr.mount_gradio_app(app, gradio_interface, path="/ui")