# predict.py
"""FastAPI service that separates uploaded audio into stems with Demucs."""
import torchaudio
import demucs.separate
from fastapi import FastAPI, UploadFile, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
import shutil
import os
import uuid
import tempfile
import logging
from typing import BinaryIO, Dict, Optional

app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

STEMS = ["vocals", "drums", "bass", "other"]
MODEL_NAME = "htdemucs"
ALLOWED_EXTENSIONS = (".mp3", ".wav", ".flac", ".ogg", ".m4a")
# Parent directory for per-request output folders. None -> the system temp dir.
# Output folders are NOT deleted automatically (their paths are returned to the
# client), so deployments should prune this location or upload results elsewhere.
OUTPUT_ROOT = os.environ.get("DEMUCS_OUTPUT_ROOT") or None


def _audio_suffix(filename: Optional[str]) -> Optional[str]:
    """Return the lower-cased allowed extension that *filename* ends with, or None.

    A missing/empty filename is treated as unsupported rather than crashing.
    """
    if not filename:
        return None
    lowered = filename.lower()
    return next((ext for ext in ALLOWED_EXTENSIONS if lowered.endswith(ext)), None)


def _run_demucs(audio_path: str, output_dir: str) -> None:
    """Run Demucs separation (MP3 output, CPU) on *audio_path* into *output_dir*."""
    try:
        demucs.separate.main(
            ["--mp3", "-n", MODEL_NAME, "-d", "cpu", "-o", output_dir, audio_path]
        )
    except SystemExit as exc:
        # NOTE(review): main() is a CLI entry point and may call sys.exit() on bad
        # input; SystemExit would bypass `except Exception`, so convert failures.
        if exc.code not in (None, 0):
            raise RuntimeError(f"Demucs exited with status {exc.code}") from exc


def _collect_stems(output_dir: str, base: str) -> Dict[str, str]:
    """Map each name in STEMS to its MP3 path, raising HTTP 500 if any is missing."""
    track_dir = os.path.join(output_dir, MODEL_NAME, base)
    stem_files = {}
    for stem in STEMS:
        path = os.path.join(track_dir, f"{stem}.mp3")
        if not os.path.exists(path):
            raise HTTPException(status_code=500, detail=f"Stem {stem} not found.")
        stem_files[stem] = path
    return stem_files


def _separate_upload(src: BinaryIO, suffix: str, output_dir: str) -> Dict[str, str]:
    """Save the upload to a scratch file, separate it, and return the stem paths.

    Blocking (disk I/O + model inference); call it from a worker thread.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Server-generated name: the client filename never touches the filesystem.
        # Demucs names its track folder after the input file's stem, i.e. this hex id.
        base = uuid.uuid4().hex
        audio_path = os.path.join(tmpdir, f"{base}{suffix}")
        with open(audio_path, "wb") as f:
            shutil.copyfileobj(src, f)
        _run_demucs(audio_path, output_dir)
    # The scratch input is gone now; the stems live in output_dir, which persists.
    return _collect_stems(output_dir, base)


@app.post("/predict")
async def predict(audio: UploadFile):
    """Separate an uploaded audio file into vocals/drums/bass/other stems.

    Args:
        audio: Uploaded file; its name must end in one of ALLOWED_EXTENSIONS.

    Returns:
        ``{"stems": {stem_name: path}}`` where each path is an MP3 inside a fresh
        per-request directory that outlives the request (for demo use; in prod,
        upload to S3/CDN and return URLs instead). To stream a single stem back,
        ``FileResponse(path, media_type="audio/mpeg")`` works since the file persists.

    Raises:
        HTTPException: 400 for a missing/unsupported filename; 500 if separation
            fails or an expected stem file is missing.
    """
    suffix = _audio_suffix(audio.filename)
    if suffix is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported file type.")

    if OUTPUT_ROOT is not None:
        os.makedirs(OUTPUT_ROOT, exist_ok=True)
    # mkdtemp (not TemporaryDirectory) so the returned stem paths remain valid.
    output_dir = tempfile.mkdtemp(prefix="demucs_", dir=OUTPUT_ROOT)
    try:
        # Offload blocking copy + inference so the event loop keeps serving requests.
        stem_files = await run_in_threadpool(_separate_upload, audio.file, suffix, output_dir)
    except Exception as e:
        # Don't leave partial output behind on failure.
        shutil.rmtree(output_dir, ignore_errors=True)
        if isinstance(e, HTTPException):
            raise  # already a client-facing error; don't re-wrap its detail
        logger.exception("Error during separation")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"stems": stem_files}