srisuriyas commited on
Commit
e67d2cf
·
verified ·
1 Parent(s): 1b47516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -1,45 +1,39 @@
1
  from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
  from transformers import pipeline
5
  import uvicorn
6
  import tempfile
 
7
 
8
- # Initialize FastAPI
9
  app = FastAPI()
10
 
11
- # Enable CORS for all origins (so Render or any client can access it)
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
15
- allow_credentials=True,
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
- # Load the pretrained speech emotion recognition pipeline
21
- emotion_pipeline = pipeline(
22
- "audio-classification",
23
- model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
24
- )
25
-
26
- # Health check route
27
- @app.get("/")
28
- def read_root():
29
- return {"message": "HF Space is live!"}
30
 
31
- # Predict route
32
  @app.post("/predict")
33
- async def predict_emotion(file: UploadFile = File(...)):
34
  try:
35
- # Save the uploaded audio file to a temporary location
36
- with tempfile.NamedTemporaryFile(delete=False) as tmp:
37
  tmp.write(await file.read())
38
  tmp_path = tmp.name
39
 
40
- # Run emotion prediction
41
- result = emotion_pipeline(tmp_path)
42
- top_emotion = result[0]['label']
 
 
 
 
 
43
 
44
  return {"emotion": top_emotion}
45
 
 
1
  from fastapi import FastAPI, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  from transformers import pipeline
4
  import uvicorn
5
  import tempfile
6
+ import torchaudio
7
 
 
8
  app = FastAPI()
9
 
10
+ # Allow CORS
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
 
14
  allow_methods=["*"],
15
  allow_headers=["*"],
16
  )
17
 
18
+ # Load model
19
+ pipe = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
 
 
 
 
 
 
 
 
20
 
 
21
  @app.post("/predict")
22
+ async def predict(file: UploadFile = File(...)):
23
  try:
24
+ # Save uploaded file to a temp file
25
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
26
  tmp.write(await file.read())
27
  tmp_path = tmp.name
28
 
29
+ # Load and preprocess audio
30
+ waveform, sample_rate = torchaudio.load(tmp_path)
31
+
32
+ # Get prediction
33
+ result = pipe(tmp_path)
34
+
35
+ # Get top prediction label
36
+ top_emotion = result[0]["label"].lower()
37
 
38
  return {"emotion": top_emotion}
39