emotune-api / app.py
srisuriyas's picture
Update app.py
e67d2cf verified
raw
history blame
1.04 kB
import os
import tempfile

import torchaudio
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
app = FastAPI()

# Allow CORS: every origin, method and header is configured with the "*"
# wildcard, so the middleware places no restriction on cross-origin callers.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Load model once at import time so every request reuses the same
# audio-classification pipeline instance.
pipe = pipeline(
    task="audio-classification",
    model="superb/wav2vec2-base-superb-er",
)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the emotion of an uploaded audio file.

    The upload is written to a temporary ``.wav`` file, handed to the
    module-level ``pipe`` audio-classification pipeline by path, and the
    label of the first result entry is returned in lower case.

    Args:
        file: The uploaded audio file from the multipart request body.

    Returns:
        ``{"emotion": <label>}`` on success, or ``{"error": <message>}`` if
        any step fails. Failures are reported in the response body rather
        than raised.
    """
    tmp_path = None
    try:
        # Save uploaded file to a temp file. delete=False lets the file be
        # reopened by path after this handle is closed; it is removed
        # explicitly in the finally clause below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            # Record the path before writing so cleanup still happens if the
            # read or write fails part-way through.
            tmp_path = tmp.name
            tmp.write(await file.read())

        # Get prediction. The pipeline is given the file path and decodes
        # the audio itself, so no separate load step is needed here.
        result = pipe(tmp_path)

        # Get top prediction label (first entry of the pipeline output).
        top_emotion = result[0]["label"].lower()
        return {"emotion": top_emotion}
    except Exception as e:
        # Top-level request boundary: report the failure to the client
        # using the same JSON error shape as before.
        return {"error": str(e)}
    finally:
        # Always remove the temp file so repeated requests do not fill the
        # disk with leftover uploads.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a failed delete must not mask the
                # response already being returned.
                pass