# Hugging Face Space: audio emotion detection service (FastAPI API + Gradio demo UI).
# Standard library
import base64
import os
import tempfile

# Third-party
import gradio as gr
import soundfile as sf
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import pipeline
# --- Load Model ---
# Load the emotion-recognition pipeline once at import time. If loading
# fails, keep the error message so the endpoints can report it instead
# of crashing the whole app.
classifier = None
model_load_error = None
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as load_exc:
    model_load_error = str(load_exc)
# --- FastAPI App ---
app = FastAPI()

# CORS: allow the browser front-end hosted at tknassetshub.io to call
# this API (handles preflight OPTIONS and credentialed requests).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://tknassetshub.io"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Dedicated JSON API endpoint for external sites.
# FIX: the original function had no route decorator, so it was never
# registered with FastAPI and the endpoint was unreachable. The UI
# description states the API lives at /api/predict/.
@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    """Classify the emotion in a base64-encoded audio clip.

    Expects a JSON body like ``{"data": "data:audio/wav;base64,AABB..."}``
    (the string a JS FileReader produces). Returns ``{"data": [...]}`` with
    the classifier's predictions, or ``{"error": ...}`` with status 400
    (bad input), 503 (model not loaded), or 500 (unexpected failure).
    """
    if classifier is None:
        return JSONResponse(content={"error": f"Model is not loaded: {model_load_error}"}, status_code=503)
    try:
        body = await request.json()
        # The JS FileReader sends a string like "data:audio/wav;base64,AABBCC..."
        base64_with_prefix = body.get("data")
        if not base64_with_prefix:
            return JSONResponse(content={"error": "Missing 'data' field in request body."}, status_code=400)
        # Robustly strip the data-URI prefix to get the pure base64 payload.
        try:
            header, encoded = base64_with_prefix.split(",", 1)
            audio_data = base64.b64decode(encoded)
        except (ValueError, TypeError):
            return JSONResponse(content={"error": "Invalid base64 data format. Please send the full data URI."}, status_code=400)
        # The pipeline wants a file path, so round-trip through a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name
        # FIX: try/finally guarantees the temp file is removed even when
        # inference raises (the original relied on a locals() check in the
        # outer except clause).
        try:
            results = classifier(temp_audio_path)
        finally:
            if os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)
        return JSONResponse(content={"data": results})
    except Exception as e:
        return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)
# --- Gradio UI (for demonstration on the Space's page) ---
def gradio_predict_wrapper(audio_file_path):
    """Run the classifier on a recorded/uploaded file for the demo UI.

    Returns a label->score mapping suitable for gr.Label, or a dict with
    an "error" key when the model is unavailable or inference fails.
    """
    if classifier is None:
        return {"error": f"Model is not loaded: {model_load_error}"}
    if audio_file_path is None:
        return {"error": "Please provide an audio file."}
    try:
        predictions = classifier(audio_file_path, top_k=5)
        return {entry["label"]: entry["score"] for entry in predictions}
    except Exception as exc:
        return {"error": str(exc)}
# Demo interface shown on the Space's landing page.
gradio_interface = gr.Interface(
    fn=gradio_predict_wrapper,
    title="Audio Emotion Detector",
    description="This UI is for direct demonstration. The primary API for websites is at /api/predict/",
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
    outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
    allow_flagging="never",
)

# Serve the Gradio UI from a subpath so FastAPI keeps the root namespace
# for the JSON API.
app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")

# Entry point used by Hugging Face Spaces (and local runs).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)