"""Audio emotion detection service for a Hugging Face Space.

Exposes two things:
  * a robust JSON API at POST /api/predict/ (consumed by a WordPress site),
  * a Gradio demo UI mounted at /gradio.
"""

import base64
import os
import tempfile

import gradio as gr
import soundfile as sf  # noqa: F401 -- kept from original; audio backend may rely on it
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import pipeline

# --- Load Model ---
# Load the emotion-recognition pipeline once at startup. If loading fails
# (missing weights, no network, ...), remember the error message so the
# endpoints can report it instead of crashing the whole app.
try:
    classifier = pipeline(
        "audio-classification", model="superb/wav2vec2-base-superb-er"
    )
except Exception as e:
    classifier = None
    model_load_error = str(e)
else:
    model_load_error = None

# --- FastAPI App ---
app = FastAPI()

# CORS middleware: grants the WordPress front end at tknassetshub.io
# permission to call this API from the browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://tknassetshub.io"],  # This gives your specific domain permission.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# This is our dedicated, robust API endpoint
@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    """Classify the emotion of a base64-encoded audio clip.

    Expects a JSON body of the form ``{"data": "data:audio/wav;base64,..."}``
    (the data URI produced by a JS FileReader) and returns
    ``{"data": [{"label": ..., "score": ...}, ...]}``.

    Error responses: 503 if the model failed to load, 400 for a missing or
    malformed payload, 500 for any unexpected failure during inference.
    """
    if classifier is None:
        return JSONResponse(
            content={"error": f"Model is not loaded: {model_load_error}"},
            status_code=503,
        )

    temp_audio_path = None
    try:
        body = await request.json()

        # The JS FileReader sends a string like "data:audio/wav;base64,AABBCC..."
        base64_with_prefix = body.get("data")
        if not base64_with_prefix:
            return JSONResponse(
                content={"error": "Missing 'data' field in request body."},
                status_code=400,
            )

        # Robustly strip the "data:...;base64," prefix to get the pure base64 data.
        try:
            _header, encoded = base64_with_prefix.split(",", 1)
            audio_data = base64.b64decode(encoded)
        except (ValueError, TypeError):
            return JSONResponse(
                content={"error": "Invalid base64 data format. Please send the full data URI."},
                status_code=400,
            )

        # The pipeline reads from a file path, so write the decoded bytes to a
        # temporary .wav file for it to process.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name

        results = classifier(temp_audio_path)

        # Return a successful response with the data.
        return JSONResponse(content={"data": results})
    except Exception as e:
        return JSONResponse(
            content={"error": f"Internal server error during prediction: {str(e)}"},
            status_code=500,
        )
    finally:
        # Always clean up the temp file, on success and on error alike.
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)


# --- Gradio UI (for demonstration on the Space's page) ---
def gradio_predict_wrapper(audio_file_path):
    """Run the classifier on a recorded/uploaded file for the Gradio demo.

    Returns a {label: score} mapping for the top 5 predictions, or an
    {"error": ...} dict when the model is unavailable or inference fails.
    """
    if classifier is None:
        return {"error": f"Model is not loaded: {model_load_error}"}
    if audio_file_path is None:
        return {"error": "Please provide an audio file."}
    try:
        results = classifier(audio_file_path, top_k=5)
        return {item["label"]: item["score"] for item in results}
    except Exception as e:
        return {"error": str(e)}


gradio_interface = gr.Interface(
    fn=gradio_predict_wrapper,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Upload Audio or Record",
    ),
    outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
    title="Audio Emotion Detector",
    description="This UI is for direct demonstration. The primary API for websites is at /api/predict/",
    allow_flagging="never",
)

# Mount the Gradio UI onto a subpath of our FastAPI app.
app = gr.mount_gradio_app(app, gradio_interface, path="/gradio")

# The Uvicorn server launch command (used by Hugging Face Spaces).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)