audio / predict.py
PreciousMposa's picture
Upload 107 files
519d358 verified
# predict.py
import logging
import os
import shutil
import tempfile
import uuid

import demucs.separate
import torchaudio
from fastapi import FastAPI, UploadFile, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
# FastAPI application object; the route decorators below register on it.
app = FastAPI()
# Configure the root logger so INFO-and-above records are emitted.
logging.basicConfig(level=logging.INFO)
# Stem names whose "<stem>.mp3" files the /predict handler looks for in the
# Demucs output directory.
STEMS = ["vocals", "drums", "bass", "other"]
# File extensions accepted by /predict (compared case-insensitively).
_ALLOWED_EXTENSIONS = (".mp3", ".wav", ".flac", ".ogg", ".m4a")


def _safe_upload_name(filename):
    """Return only the final path component of a client-supplied filename.

    The name comes straight from the request, so any directory part (with
    POSIX or Windows separators) is dropped before it is used to build a
    filesystem path. A missing name (``None``) yields ``""``.
    """
    return os.path.basename((filename or "").replace("\\", "/"))


def _separate_to_stems(upload_file, safe_name):
    """Save the upload to a temp directory, run Demucs, and map stems to paths.

    This function blocks (file copy plus model inference), so the endpoint
    runs it in a worker thread rather than on the event loop.

    Args:
        upload_file: Readable binary file object holding the uploaded audio.
        safe_name: Sanitised file name (no directory components).

    Returns:
        Dict mapping each name in ``STEMS`` to the path of its ``.mp3`` file.

    Raises:
        HTTPException: 500 if an expected stem file was not produced.
        RuntimeError: if the Demucs entry point exits via ``sys.exit``.
    """
    # Use a unique temp directory for each request.
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f"{uuid.uuid4()}_{safe_name}")

        # Save uploaded file.
        with open(audio_path, "wb") as f:
            shutil.copyfileobj(upload_file, f)

        # Run Demucs separation.
        output_dir = os.path.join(tmpdir, "separated")
        os.makedirs(output_dir, exist_ok=True)
        try:
            demucs.separate.main(
                ["--mp3", "-n", "htdemucs", "-d", "cpu", "-o", output_dir, audio_path]
            )
        except SystemExit as exc:
            # NOTE(review): main() is a CLI-style entry point and presumably
            # can call sys.exit() on bad arguments or fatal errors. SystemExit
            # is not an Exception subclass, so convert it into one the
            # endpoint's error handler will report as a 500.
            raise RuntimeError(f"Demucs exited with status {exc.code}") from exc

        # Find output stems under <output_dir>/htdemucs/<track name>/.
        base = os.path.splitext(os.path.basename(audio_path))[0]
        stem_files = {}
        for stem in STEMS:
            path = os.path.join(output_dir, "htdemucs", base, f"{stem}.mp3")
            if not os.path.exists(path):
                raise HTTPException(status_code=500, detail=f"Stem {stem} not found.")
            stem_files[stem] = path
        return stem_files


@app.post("/predict")
async def predict(audio: UploadFile):
    """Separate an uploaded audio file into stems with Demucs.

    Args:
        audio: The uploaded file. Its name must end in one of
            ``.mp3``, ``.wav``, ``.flac``, ``.ogg`` or ``.m4a``.

    Returns:
        ``{"stems": {stem_name: path}}`` with one entry per name in ``STEMS``.

    Raises:
        HTTPException: 400 for a missing or unsupported file name; 500 if
            separation fails or an expected stem file is missing.

    NOTE(review): the returned paths live inside a ``TemporaryDirectory``
    that is removed when processing finishes, so they no longer exist by the
    time the client receives them. As the original comments say, this is
    demo behaviour; in production, upload the stems to S3/a CDN and return
    URLs, or stream a file back with ``FileResponse``.
    """
    # Validate file type on the sanitised name; also covers filename=None.
    safe_name = _safe_upload_name(audio.filename)
    if not safe_name.lower().endswith(_ALLOWED_EXTENSIONS):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported file type.")

    try:
        # Separation is blocking and CPU-heavy: run it in a worker thread so
        # the event loop keeps serving other requests.
        stem_files = await run_in_threadpool(_separate_to_stems, audio.file, safe_name)
    except HTTPException:
        # Already a well-formed HTTP error (e.g. missing stem); don't re-wrap.
        raise
    except Exception as e:
        logging.exception("Error during separation")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"stems": stem_files}