import gradio as gr
from transformers import pipeline
import soundfile as sf
import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware # <--- 1. ADD THIS IMPORT

# --- Load Model ---
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as e:
    classifier = None
    model_load_error = str(e)
else:
    model_load_error = None
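
# Note: superb/wav2vec2-base-superb-er is the SUPERB emotion-recognition checkpoint
# trained on IEMOCAP; it typically predicts four classes ("neu", "hap", "ang", "sad"),
# each with a confidence score.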

# --- FastAPI App ---
app = FastAPI()

# --- 2. ADD THIS ENTIRE BLOCK ---
# This block adds the CORS middleware to allow your WordPress site to make requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://tknassetshub.io"],  # This gives your specific domain permission.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------
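
# (Illustrative, not part of the original config): if you later need to test from another
# origin, you could extend the list above, e.g.
#   allow_origins=["https://tknassetshub.io", "http://localhost:3000"]
# The localhost entry is only a hypothetical example for local development.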

# This is our dedicated, robust API endpoint
@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    if classifier is None:
        return JSONResponse(content={"error": f"Model is not loaded: {model_load_error}"}, status_code=503)
    try:
        body = await request.json()

        # The JS FileReader sends a string like "data:audio/wav;base64,AABBCC..."
        base64_with_prefix = body.get("data")
        if not base64_with_prefix:
            return JSONResponse(content={"error": "Missing 'data' field in request body."}, status_code=400)

        # Robustly strip the prefix to get the pure base64 data
        try:
            header, encoded = base64_with_prefix.split(",", 1)
            audio_data = base64.b64decode(encoded)
        except (ValueError, TypeError):
            return JSONResponse(content={"error": "Invalid base64 data format. Please send the full data URI."}, status_code=400)

        # Write to a temporary file for the pipeline to process
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name

        results = classifier(temp_audio_path)
        os.unlink(temp_audio_path)  # Clean up the temp file

        # Return a successful response with the data
        return JSONResponse(content={"data": results})
    except Exception as e:
        # Clean up the temp file if it exists even after an error
        if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)
        return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)

# --- Gradio UI (for demonstration on the Space's page) ---
def gradio_predict_wrapper(audio_file_path):
    if classifier is None:
        return {"error": f"Model is not loaded: {model_load_error}"}
    if audio_file_path is None:
        return {"error": "Please provide an audio file."}
    try:
        results = classifier(audio_file_path, top_k=5)
        return {item['label']: item['score'] for item in results}
    except Exception as e:
        return {"error": str(e)}

gradio_interface = gr.Interface(
    fn=gradio_predict_wrapper,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
    outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
    title="Audio Emotion Detector",
    description="This UI is for direct demonstration. The primary API for websites is at /api/predict/",
    allow_flagging="never"
)

# Mount the Gradio UI onto a subpath of our FastAPI app
app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")
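
# With this layout the demo UI is served under /gradio while the JSON API stays at
# /api/predict/ on the same server; the public hostname depends on how the Space is named.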

# The Uvicorn server launch command (used by Hugging Face Spaces)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)