Spaces:
Configuration error
# predict.py | |
import torchaudio | |
import demucs.separate | |
from fastapi import FastAPI, UploadFile, HTTPException, status | |
from fastapi.responses import FileResponse | |
import shutil | |
import os | |
import uuid | |
import tempfile | |
import logging | |
# FastAPI application instance for this module.
app = FastAPI()
# Configure the root logger at INFO (a no-op if it already has handlers).
logging.basicConfig(level="INFO")
# Stem names expected in Demucs' htdemucs output folder; predict() looks each
# one up as "<stem>.mp3".
STEMS = "vocals drums bass other".split()
# Upload extensions accepted by predict() (matched case-insensitively).
_ALLOWED_EXTENSIONS = (".mp3", ".wav", ".flac", ".ogg", ".m4a")


def _collect_stems(track_dir):
    """Map each name in STEMS to its ``<stem>.mp3`` path inside *track_dir*.

    Raises:
        HTTPException: 500 if any expected stem file does not exist.
    """
    stem_files = {}
    for stem in STEMS:
        path = os.path.join(track_dir, f"{stem}.mp3")
        if not os.path.exists(path):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Stem {stem} not found.",
            )
        stem_files[stem] = path
    return stem_files


# NOTE(review): no route decorator is visible for this handler; if it is not
# registered elsewhere, add e.g. ``@app.post("/predict")``.
async def predict(audio: UploadFile):
    """Separate an uploaded audio file into stems with Demucs (htdemucs, CPU).

    The upload is written to a per-request temporary directory that is removed
    once separation finishes. The stems are written to a directory created with
    ``tempfile.mkdtemp`` that is NOT removed on success, so the returned paths
    still exist after this call returns. Deleting it is the caller's or the
    deployment's job; in production, prefer uploading the stems to object
    storage and returning URLs. On any failure that directory is removed.

    Args:
        audio: The uploaded file. Its name must end in one of
            ``_ALLOWED_EXTENSIONS`` (case-insensitive).

    Returns:
        ``{"stems": {stem_name: path_to_mp3}}`` for every name in ``STEMS``.

    Raises:
        HTTPException: 400 if the filename is missing or has an unsupported
            extension; 500 if separation fails or an expected stem is missing.
    """
    from fastapi.concurrency import run_in_threadpool

    # Validate file type. ``filename`` may be None or empty, so guard before lower().
    lowered = (audio.filename or "").lower()
    ext = next((e for e in _ALLOWED_EXTENSIONS if lowered.endswith(e)), None)
    if ext is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported file type.")

    # Stems go to a directory that outlives this request: returning paths from
    # inside a TemporaryDirectory would hand back files that are already deleted.
    output_root = tempfile.mkdtemp(prefix="demucs_")
    succeeded = False
    try:
        # The raw upload is only needed during separation.
        with tempfile.TemporaryDirectory() as tmpdir:
            # Never use the client-supplied name as a path component (it may
            # contain separators). A random name plus the validated extension
            # is enough for the decoder.
            base = uuid.uuid4().hex
            audio_path = os.path.join(tmpdir, base + ext)
            with open(audio_path, "wb") as f:
                shutil.copyfileobj(audio.file, f)
            # Separation is a blocking call; run it in a worker thread so the
            # event loop keeps serving other requests.
            # NOTE(review): concurrent requests now separate in parallel; bound
            # this with an asyncio.Semaphore if memory is a concern.
            await run_in_threadpool(
                demucs.separate.main,
                ["--mp3", "-n", "htdemucs", "-d", "cpu", "-o", output_root, audio_path],
            )
        # Output layout: <output_root>/htdemucs/<track name>/<stem>.mp3
        result = {"stems": _collect_stems(os.path.join(output_root, "htdemucs", base))}
        succeeded = True
        return result
    except HTTPException:
        # Already a well-formed HTTP error; don't re-wrap it as a generic 500.
        raise
    except SystemExit as e:
        # argparse (and possibly Demucs on fatal errors) exits via SystemExit,
        # which ``except Exception`` below would not catch.
        logging.exception("Demucs exited during separation")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Audio separation failed.",
        ) from e
    except Exception as e:
        logging.exception("Error during separation")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
    finally:
        # Don't leave partial output behind when the request fails.
        if not succeeded:
            shutil.rmtree(output_root, ignore_errors=True)