File size: 2,138 Bytes
519d358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# predict.py
import logging
import os
import shutil
import tempfile
import uuid

import torchaudio
import demucs.separate
from fastapi import FastAPI, UploadFile, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse

# Root logger at INFO so the service's logging calls are emitted.
logging.basicConfig(level=logging.INFO)

# ASGI application exposing the /predict endpoint.
app = FastAPI()

# Stem names looked up in the htdemucs output for every separated track.
STEMS = "vocals drums bass other".split()

# Upload extensions accepted by /predict, matched case-insensitively against
# the end of the client-supplied filename.
_SUPPORTED_EXTENSIONS = (".mp3", ".wav", ".flac", ".ogg", ".m4a")


def _collect_stems(output_dir, track):
    """Return ``{stem: path}`` for each name in ``STEMS`` produced for *track*.

    Demucs is invoked with ``-n htdemucs --mp3``, so each stem is looked up at
    ``<output_dir>/htdemucs/<track>/<stem>.mp3``.

    Raises:
        HTTPException: 500 naming the first stem whose file is missing.
    """
    stem_files = {}
    for stem in STEMS:
        path = os.path.join(output_dir, "htdemucs", track, f"{stem}.mp3")
        if not os.path.exists(path):
            raise HTTPException(status_code=500, detail=f"Stem {stem} not found.")
        stem_files[stem] = path
    return stem_files


def _separate_upload(src, suffix):
    """Write *src* to disk, run Demucs on it, and return the stem paths.

    Blocking (file copy plus CPU-bound separation); call it off the event loop.

    The uploaded copy lives in a ``TemporaryDirectory`` that is removed as soon
    as Demucs finishes. The stems are written to a separate ``mkdtemp``
    directory that is kept on success, so the returned paths remain valid after
    the request ends. That directory is removed if anything fails.

    Args:
        src: Binary file-like object holding the upload.
        suffix: Lower-case extension, including the leading dot, for the copy.

    Returns:
        ``{stem: path}`` as produced by ``_collect_stems``.
    """
    # A random track name avoids collisions and keeps the client-supplied
    # filename (which may contain path separators) out of every path.
    track = uuid.uuid4().hex
    output_dir = tempfile.mkdtemp(prefix="demucs_")
    try:
        with tempfile.TemporaryDirectory() as upload_dir:
            audio_path = os.path.join(upload_dir, track + suffix)
            with open(audio_path, "wb") as f:
                shutil.copyfileobj(src, f)
            demucs.separate.main(
                ["--mp3", "-n", "htdemucs", "-d", "cpu", "-o", output_dir, audio_path]
            )
        return _collect_stems(output_dir, track)
    except BaseException:
        # BaseException so cleanup also runs for non-Exception exits such as
        # SystemExit; the exception is always re-raised.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise


@app.post("/predict")
async def predict(audio: UploadFile):
    """Split an uploaded audio file into Demucs ``htdemucs`` stems.

    Args:
        audio: Uploaded track. Its filename must end in one of
            ``_SUPPORTED_EXTENSIONS`` (case-insensitive).

    Returns:
        ``{"stems": {stem_name: path}}`` with one MP3 path per entry in
        ``STEMS``. The paths are on the server's filesystem and are not
        deleted by this endpoint, so they stay valid after the response.

    Raises:
        HTTPException: 400 if the filename is missing or has an unsupported
            extension. 500 if separation fails or a stem file is missing.
    """
    # ``filename`` can be None when the client omits it; reject it as
    # unsupported instead of crashing on ``None.lower()``.
    filename = (audio.filename or "").lower()
    suffix = next((ext for ext in _SUPPORTED_EXTENSIONS if filename.endswith(ext)), None)
    if suffix is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unsupported file type.")

    try:
        # The copy and the Demucs run are blocking; run them in a worker thread
        # so the event loop keeps serving other requests.
        # NOTE(review): concurrent requests now separate in parallel, which may
        # need a concurrency limit on small hosts — confirm capacity.
        stem_files = await run_in_threadpool(_separate_upload, audio.file, suffix)
    except HTTPException:
        # Already carries the intended status and detail (e.g. a missing stem).
        raise
    except Exception as e:
        logging.exception("Error during separation")
        # The full traceback is logged above; don't echo internal exception
        # text (server paths, library internals) back to the client.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Audio separation failed.",
        ) from e

    # Demo response: server-local paths. In production, upload the stems to
    # object storage and return URLs instead.
    # TODO: nothing removes these output directories yet; add a cleanup policy.
    return {"stems": stem_files}