File size: 3,786 Bytes
cae86cc
 
 
 
5cde27f
 
7a18e70
 
 
07df6b8
cae86cc
7a18e70
0ea75df
6608513
0ea75df
6608513
7a18e70
 
 
cae86cc
c2cb49b
7a18e70
 
07df6b8
 
 
 
 
 
 
 
 
 
 
 
c2cb49b
7a18e70
 
5cde27f
1bbfae6
5cde27f
 
7a18e70
1bbfae6
 
7a18e70
1bbfae6
 
 
 
 
 
 
 
c2cb49b
1bbfae6
c2cb49b
7a18e70
5cde27f
 
7a18e70
c2cb49b
1bbfae6
 
cd9d52d
1bbfae6
7a18e70
5cde27f
c2cb49b
 
 
1bbfae6
 
c2cb49b
 
 
 
 
 
 
 
 
cae86cc
7a18e70
1bbfae6
7a18e70
 
 
c2cb49b
7a18e70
cae86cc
 
c2cb49b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from transformers import pipeline
import soundfile as sf
import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware # <--- 1. ADD THIS IMPORT

# --- Load Model ---
# The emotion classifier is loaded once, at import time. A load failure must not
# take the server down: `classifier` is left as None and the failure text is kept
# in `model_load_error`, which the request handlers report back to callers.
classifier = None
model_load_error = None
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as exc:
    model_load_error = str(exc)

# --- FastAPI App ---
app = FastAPI()

# Browser origins that may call this API cross-origin (the WordPress site).
_ALLOWED_ORIGINS = ["https://tknassetshub.io"]

# CORS middleware: lets pages served from the origins above make requests here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
)


def _decode_data_uri(data_uri):
    """Return the bytes encoded in a ``data:<mime>;base64,<payload>`` string.

    Everything up to the first ``,`` is treated as the header and discarded.
    Returns None when *data_uri* is not a string, contains no ``,``, or its
    payload is not valid base64.
    """
    if not isinstance(data_uri, str):
        return None
    _header, separator, encoded = data_uri.partition(",")
    if not separator:
        return None
    try:
        return base64.b64decode(encoded)
    except ValueError:  # binascii.Error (bad padding / non-ASCII input) subclasses ValueError
        return None


# This is our dedicated, robust API endpoint
@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    """Classify the emotion in an audio clip posted as a base64 data URI.

    Expects a JSON object body ``{"data": "<data URI>"}``, e.g. the string a
    browser ``FileReader`` produces: ``"data:audio/wav;base64,AABBCC..."``.

    Returns a ``JSONResponse``:
        200 ``{"data": results}`` with the classifier pipeline's raw output.
        400 ``{"error": ...}`` if the body is not a JSON object, ``data`` is
            missing, or ``data`` is not a well-formed base64 data URI / decodes
            to nothing.
        503 ``{"error": ...}`` if the model failed to load at startup.
        500 ``{"error": ...}`` if writing the temp file or inference fails.
    """
    # Local import so this block brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    if classifier is None:
        return JSONResponse(content={"error": f"Model is not loaded: {model_load_error}"}, status_code=503)

    # Client mistakes are reported as 400s rather than falling into the generic 500 path.
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError both subclass ValueError
        return JSONResponse(content={"error": "Request body must be valid JSON."}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse(content={"error": "Request body must be a JSON object."}, status_code=400)

    base64_with_prefix = body.get("data")
    if not base64_with_prefix:
        return JSONResponse(content={"error": "Missing 'data' field in request body."}, status_code=400)

    audio_data = _decode_data_uri(base64_with_prefix)
    if audio_data is None:
        return JSONResponse(content={"error": "Invalid base64 data format. Please send the full data URI."}, status_code=400)
    if not audio_data:
        return JSONResponse(content={"error": "Decoded audio data is empty."}, status_code=400)

    temp_audio_path = None
    try:
        # The pipeline is given a file path, so persist the decoded bytes first.
        # Record the name before writing so a failed write still gets cleaned up.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_audio_path = temp_file.name
            temp_file.write(audio_data)

        # Model inference is synchronous; run it in a worker thread so it does not
        # block the event loop (and every other request) while it runs.
        results = await run_in_threadpool(classifier, temp_audio_path)

        # Return a successful response with the data
        return JSONResponse(content={"data": results})
    except Exception as e:
        return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)
    finally:
        # Always remove the temp file, on success and on failure.
        if temp_audio_path is not None and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)

# --- Gradio UI (for demonstration on the Space's page) ---
def gradio_predict_wrapper(audio_file_path):
    """Classify an audio file for the Gradio demo UI.

    Args:
        audio_file_path: Path to the uploaded or recorded clip (the ``gr.Audio``
            input below uses ``type="filepath"``), or None if nothing was provided.

    Returns:
        dict mapping each of the top-5 predicted labels to its score, which the
        ``gr.Label`` output renders as confidences.

    Raises:
        gr.Error: if the model is not loaded, no audio was given, or inference
            fails. Gradio shows the message in the UI. An error string placed
            inside the returned dict is not a numeric confidence, so ``gr.Label``
            cannot display it.
    """
    if classifier is None:
        raise gr.Error(f"Model is not loaded: {model_load_error}")
    if audio_file_path is None:
        raise gr.Error("Please provide an audio file.")
    try:
        results = classifier(audio_file_path, top_k=5)
        return {item['label']: item['score'] for item in results}
    except Exception as e:
        raise gr.Error(str(e)) from e

# Components for the demo: an audio input (mic or upload, delivered to the
# wrapper as a file path) and a label output showing the top 5 classes.
_demo_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record")
_demo_label_output = gr.Label(num_top_classes=5, label="Emotion Predictions")

gradio_interface = gr.Interface(
    fn=gradio_predict_wrapper,
    inputs=_demo_audio_input,
    outputs=_demo_label_output,
    title="Audio Emotion Detector",
    description="This UI is for direct demonstration. The primary API for websites is at /api/predict/",
    allow_flagging="never",
)

# Serve the demo UI under /gradio on the same FastAPI app. mount_gradio_app
# returns the app object to serve, so rebind `app` to it.
app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")

# Script entry point: serve the combined FastAPI + Gradio app with Uvicorn,
# listening on all interfaces at port 7860 (the launch used by Hugging Face Spaces).
if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")