Diggz10 committed on
Commit
c2cb49b
·
verified ·
1 Parent(s): 1bbfae6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -15
app.py CHANGED
@@ -17,9 +17,10 @@ except Exception as e:
17
  else:
18
  model_load_error = None
19
 
20
- # --- FastAPI App for a dedicated, robust API ---
21
  app = FastAPI()
22
 
 
23
  @app.post("/api/predict/")
24
  async def predict_emotion_api(request: Request):
25
  if classifier is None:
@@ -39,38 +40,52 @@ async def predict_emotion_api(request: Request):
39
  header, encoded = base64_with_prefix.split(",", 1)
40
  audio_data = base64.b64decode(encoded)
41
  except (ValueError, TypeError):
42
- return JSONResponse(content={"error": "Invalid base64 data format."}, status_code=400)
43
 
44
- # Write to a temporary file for the pipeline
45
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
46
  temp_file.write(audio_data)
47
  temp_audio_path = temp_file.name
48
 
49
- results = classifier(temp_audio_path, top_k=5)
50
  os.unlink(temp_audio_path) # Clean up the temp file
51
 
52
- # Return a successful response
 
 
53
  return JSONResponse(content={"data": results})
54
 
55
  except Exception as e:
 
 
 
56
  return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)
57
 
58
- # --- Gradio UI function (optional, for the direct Space page) ---
59
- def gradio_predict_wrapper(audio_file):
60
- # This is just for the UI on the Hugging Face page itself
61
- if audio_file is None: return {"error": "Please provide an audio file."}
62
- results = classifier(audio_file, top_k=5)
63
- return {item['label']: round(item['score'], 3) for item in results}
 
 
 
 
 
64
 
65
  gradio_interface = gr.Interface(
66
  fn=gradio_predict_wrapper,
67
  inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
68
  outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
69
  title="Audio Emotion Detector",
70
- description="This UI is for direct demonstration. The primary API is at /api/predict/",
71
  allow_flagging="never"
72
  )
73
 
74
- # --- Mount the Gradio UI onto the FastAPI app ---
75
- # The API at /api/predict/ will work even if the UI is at a different path.
76
- app = gr.mount_gradio_app(app, gradio_interface, path="/ui")
 
 
 
 
 
17
  else:
18
  model_load_error = None
19
 
20
+ # --- FastAPI App ---
21
  app = FastAPI()
22
 
23
+ # This is our dedicated, robust API endpoint
24
  @app.post("/api/predict/")
25
  async def predict_emotion_api(request: Request):
26
  if classifier is None:
 
40
  header, encoded = base64_with_prefix.split(",", 1)
41
  audio_data = base64.b64decode(encoded)
42
  except (ValueError, TypeError):
43
+ return JSONResponse(content={"error": "Invalid base64 data format. Please send the full data URI."}, status_code=400)
44
 
45
+ # Write to a temporary file for the pipeline to process
46
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
47
  temp_file.write(audio_data)
48
  temp_audio_path = temp_file.name
49
 
50
+ results = classifier(temp_audio_path)
51
  os.unlink(temp_audio_path) # Clean up the temp file
52
 
53
+ # The transformers pipeline returns a list of dicts
54
+ # Example: [{'score': 0.99, 'label': 'happy'}, {'score': 0.01, 'label': 'sad'}]
55
+ # We will return this directly
56
  return JSONResponse(content={"data": results})
57
 
58
  except Exception as e:
59
+ # Clean up the temp file if it exists even after an error
60
+ if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
61
+ os.unlink(temp_audio_path)
62
  return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)
63
 
64
+ # --- Gradio UI (for demonstration on the Space's page) ---
65
+ def gradio_predict_wrapper(audio_file_path):
66
+ if classifier is None: return {"error": f"Model is not loaded: {model_load_error}"}
67
+ if audio_file_path is None: return {"error": "Please provide an audio file."}
68
+
69
+ try:
70
+ results = classifier(audio_file_path, top_k=5)
71
+ # Format for Gradio's Label component
72
+ return {item['label']: item['score'] for item in results}
73
+ except Exception as e:
74
+ return {"error": str(e)}
75
 
76
  gradio_interface = gr.Interface(
77
  fn=gradio_predict_wrapper,
78
  inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
79
  outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
80
  title="Audio Emotion Detector",
81
+ description="This UI is for direct demonstration. The primary API for websites is at /api/predict/",
82
  allow_flagging="never"
83
  )
84
 
85
+ # Mount the Gradio UI onto a subpath of our FastAPI app
86
+ app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")
87
+
88
+ # The Uvicorn server launch command (used by Hugging Face Spaces)
89
+ # This is the ONLY launch command needed.
90
+ if __name__ == "__main__":
91
+ uvicorn.run(app, host="0.0.0.0", port=7860)