from __future__ import annotations  # Defer annotation evaluation so np.ndarray hints don't break if the optional imports below fail

import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
import io # For zip file in memory
import zipfile
# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
# --- AI & Advanced Audio Imports ---
try:
    import torch
    from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor  # Using pipeline for simplicity where possible
    # Specific model imports might be needed depending on the chosen approach
    import soundfile as sf
    import numpy as np
    import librosa  # For resampling if needed
    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Error importing AI/Audio libraries: {e}")
    print("Ensure torch, transformers, soundfile, librosa are installed.")
    print("AI features will be unavailable.")
    torch = None
    pipeline = None
    sf = None
    np = None
    librosa = None
# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Global Variables for Loaded Models ---
# Use dictionaries to potentially hold multiple models of each type later
enhancement_pipelines: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {} # Might store pipeline or model/processor pair
# Target sampling rates for models (check model cards on Hugging Face!)
ENHANCEMENT_SR = 16000 # Example for speechbrain/sepformer
DEMUCS_SR = 44100 # Demucs default
# --- Helper Functions ---
def cleanup_file(file_path: str):
    """Safely remove a file."""
    try:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Saves an uploaded file to a temporary location and returns the path."""
    # Generate a unique temporary file path
    _, file_extension = os.path.splitext(upload_file.filename or "")
    if not file_extension:
        file_extension = ".wav"  # Default if no extension
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
            # Stream the upload in 1 MiB chunks to avoid holding large files in memory
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        await upload_file.close()
# --- Audio Loading/Saving for AI Models ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[np.ndarray, int]:
    """Loads audio using soundfile, converts to mono float32, optionally resamples."""
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
        # Convert to mono if stereo (simple channel averaging)
        if audio.ndim > 1 and audio.shape[1] > 1:
            audio = np.mean(audio, axis=1)
            logger.info("Converted audio to mono")
        # Resample if necessary
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("Librosa is required for resampling but not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
            logger.info(f"Resampled audio shape: {audio.shape}")
            current_sr = target_sr
        else:
            current_sr = orig_sr
        return audio, current_sr
    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves a NumPy audio array to a temporary file."""
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        # Ensure data is float32 for common formats like wav/flac; pydub handles mp3 etc.
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        if output_format.lower() in ['wav', 'flac']:
            # Use soundfile for lossless formats
            sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
        else:
            # For lossy formats like mp3, convert the [-1.0, 1.0] float32 array to a
            # pydub segment, clipping and scaling to the 16-bit integer range first
            audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1  # Assuming mono output from AI models for now
            )
            segment.export(output_path, format=output_format)
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")
# --- Synchronous AI Inference Functions (to be run in threads) ---
def _run_enhancement_sync(model_pipeline: Any, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Synchronous wrapper for enhancement model inference."""
    if not model_pipeline:
        raise ValueError("Enhancement model not loaded")
    try:
        logger.info(f"Running speech enhancement (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Pipeline usage depends heavily on the specific pipeline.
        # Example for a hypothetical 'audio-enhancement' pipeline:
        result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
        enhanced_audio = result["audio"]["array"]  # Adjust based on actual pipeline output
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise  # Re-raise to be caught by the async wrapper
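# --- Hedged sketch: a concrete enhancement backend via SpeechBrain ---
# Since transformers has no standard 'audio-enhancement' pipeline task, one concrete
# option is SpeechBrain's pretrained interface (pip install speechbrain). The class
# name, import path, and model ID below follow SpeechBrain's documentation but are
# assumptions here; newer releases expose the same class under speechbrain.inference.
def _run_speechbrain_enhancement_sync(audio_data: np.ndarray) -> np.ndarray:
    """Sketch: enhance 16 kHz mono float32 audio with a MetricGAN+ model (assumed API)."""
    from speechbrain.pretrained import SpectralMaskEnhancement  # Assumed import path
    # In a real app, load once at startup (e.g. in load_hf_models) rather than per call
    enhancer = SpectralMaskEnhancement.from_hparams(
        source="speechbrain/metricgan-plus-voicebank"  # Assumed model ID
    )
    noisy = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0)  # (batch, samples)
    enhanced = enhancer.enhance_batch(noisy, lengths=torch.tensor([1.0]))
    return enhanced.squeeze(0).cpu().numpy()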
def _run_separation_sync(model_pipeline: Any, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
    """Synchronous wrapper for source separation model inference."""
    if not model_pipeline:
        raise ValueError("Separation model not loaded")
    try:
        logger.info(f"Running source separation (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Usage depends on the separation model/pipeline. A hypothetical
        # 'audio-source-separation' pipeline would be called like:
        #   result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
        # Demucs, however, is typically called directly on a tensor, as below.
        model = model_pipeline  # The loaded Demucs model instance passed in by the caller
        # Demucs expects stereo input in shape (batch, channels, samples):
        # duplicate the mono channel if needed, then add the batch dimension.
        if audio_data.ndim == 1:
            audio_data = np.stack([audio_data, audio_data], axis=0)  # Create stereo from mono
        audio_tensor = torch.tensor(audio_data).unsqueeze(0)  # Add batch dimension
        # Move the input to the same device as the model (CPU or GPU)
        device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(device)
        with torch.no_grad():
            sources = model(audio_tensor)[0]  # Output shape (stems, channels, samples)
        # Detach, move to CPU, convert to numpy
        sources_np = sources.detach().cpu().numpy()
        # Convert back to mono for simplicity (average channels).
        # Important: the order (drums, bass, other, vocals) is specific to the
        # default Demucs v3/v4 models.
        stems = {
            'drums': np.mean(sources_np[0], axis=0),
            'bass': np.mean(sources_np[1], axis=0),
            'other': np.mean(sources_np[2], axis=0),
            'vocals': np.mean(sources_np[3], axis=0),
        }
        logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
        return stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
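# --- Hedged sketch: separation via the 'demucs' package itself ---
# The raw forward pass above assumes the model was loaded somehow; in practice the
# 'demucs' package (pip install -U demucs) provides apply_model(), which also handles
# chunked/overlapped inference. The import path and signature follow demucs v4 and
# are an assumption; check the installed version.
def _run_separation_demucs_pkg(model: Any, audio_data: np.ndarray) -> Dict[str, np.ndarray]:
    """Sketch: separate stems using demucs.apply.apply_model (assumed demucs v4 API)."""
    from demucs.apply import apply_model  # Assumed import; requires the demucs package
    if audio_data.ndim == 1:
        audio_data = np.stack([audio_data, audio_data], axis=0)  # Mono -> stereo
    wav = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0)  # (batch, channels, samples)
    with torch.no_grad():
        sources = apply_model(model, wav)[0]  # (stems, channels, samples)
    sources_np = sources.cpu().numpy()
    # model.sources lists the stem order for the loaded variant,
    # e.g. ['drums', 'bass', 'other', 'vocals'] for the default models
    return {name: np.mean(sources_np[i], axis=0) for i, name in enumerate(model.sources)}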
# --- Model Loading Function ---
def load_hf_models():
    """Loads Hugging Face models at startup."""
    global enhancement_pipelines, separation_models
    if torch is None or pipeline is None:
        logger.warning("Torch or Transformers not available. Skipping Hugging Face model loading.")
        return
    # --- Load Enhancement Model ---
    # Using speechbrain/sepformer-whamr-enhancement (check HF for the exact pipeline
    # task), or load the model/processor manually if no direct pipeline exists.
    enhancement_model_id = "speechbrain/sepformer-whamr-enhancement"  # Example ID
    try:
        logger.info(f"Loading enhancement model: {enhancement_model_id}...")
        # There is no standard 'audio-enhancement' pipeline task in transformers,
        # so this model would need manual loading, e.g.:
        #   enhancement_processor = AutoProcessor.from_pretrained(enhancement_model_id)
        #   enhancement_model = AutoModelForSpeechSeq2Seq.from_pretrained(enhancement_model_id)
        #   enhancement_pipelines['speechbrain_sepformer'] = {"processor": enhancement_processor, "model": enhancement_model}
        logger.warning(f"Skipping load for {enhancement_model_id} - requires a specific pipeline or manual setup.")
    except Exception as e:
        logger.error(f"Failed to load enhancement model '{enhancement_model_id}': {e}", exc_info=False)
    # --- Load Separation Model (Demucs) ---
    # Demucs is used directly, not via a standard HF pipeline task.
    separation_model_id = "facebook/demucs"  # Or a specific variant
    try:
        logger.info(f"Loading separation model: {separation_model_id}...")
        # AutoModel.from_pretrained() generally does not work for Demucs; the usual
        # route is the 'demucs' package itself (pip install -U demucs), e.g.:
        #   from demucs.pretrained import get_model
        #   separation_models['demucs'] = get_model('htdemucs')
        # For now, simulate a loading failure since no loading code is wired up.
        raise NotImplementedError("Demucs loading typically requires the 'demucs' package or specific manual loading.")
    except Exception as e:
        logger.error(f"Failed to load separation model '{separation_model_id}': {e}", exc_info=False)
        logger.warning("Note: Demucs loading often requires 'pip install demucs' and specific loading code, not just AutoModel.")
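# --- Hedged sketch: loading Demucs via the 'demucs' package ---
# A loader that load_hf_models() could call instead of raising NotImplementedError.
# Assumes 'pip install -U demucs'; the import path and the 'htdemucs' default name
# follow demucs v4 and may differ in other releases.
def load_demucs_via_package(name: str = "htdemucs") -> bool:
    """Sketch: populate separation_models['demucs'] using demucs.pretrained.get_model."""
    try:
        from demucs.pretrained import get_model  # Assumed import path (demucs v4)
        model = get_model(name)
        model.eval()  # Inference mode; no gradients needed
        separation_models['demucs'] = model
        logger.info(f"Loaded Demucs model '{name}' via the demucs package.")
        return True
    except Exception as e:
        logger.error(f"Could not load Demucs via the demucs package: {e}", exc_info=False)
        return False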
# --- FastAPI App and Endpoints ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and HF model dependencies.",
    version="2.0.0",
)
@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    logger.info("Application startup: Loading AI models...")
    # Run model loading in a separate thread so the event loop is not blocked,
    # although startup still waits for the thread to finish.
    # Consider truly background loading if startup time is critical.
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished (check logs for success/failure).")
# --- Basic Editing Endpoints (Mostly Unchanged) ---
@app.get("/", tags=["General"])
def read_root():
    """Root endpoint providing a welcome message and available features."""
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    if enhancement_pipelines:
        ai_features.append("/enhance")
    if separation_models:
        ai_features.append("/separate")
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None available (models might have failed to load)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs).",
    }
# The /concat, /volume, and /convert endpoints remain largely the same as before
# and are omitted here for brevity; they should use the updated save_upload_file
# and cleanup_file helpers. A sketch of /trim follows below as an illustration.
# ... Add /concat, /volume, /convert endpoints here ...
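# --- Hedged sketch: one of the omitted basic endpoints ---
# Illustrates the intended pattern (pydub-based editing, reusing save_upload_file
# and cleanup_file). Parameter names (start_ms, end_ms) are illustrative, not from
# the original endpoints; times are in milliseconds, matching pydub's slicing.
@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., description="Trim start position in milliseconds."),
    end_ms: int = Form(..., description="Trim end position in milliseconds."),
    output_format: str = Form("wav", description="Output format."),
):
    """Returns the [start_ms, end_ms] slice of the uploaded audio (sketch implementation)."""
    input_path = await save_upload_file(file, prefix="trim_in_")
    background_tasks.add_task(cleanup_file, input_path)
    try:
        segment = AudioSegment.from_file(input_path)
    except CouldntDecodeError:
        raise HTTPException(status_code=415, detail="Could not decode the uploaded audio file.")
    if not 0 <= start_ms < end_ms:
        raise HTTPException(status_code=422, detail="Require 0 <= start_ms < end_ms.")
    trimmed = segment[start_ms:end_ms]  # pydub slices AudioSegments by milliseconds
    output_path = os.path.join(TEMP_DIR, f"trim_out_{uuid.uuid4().hex}.{output_format}")
    trimmed.export(output_path, format=output_format)
    background_tasks.add_task(cleanup_file, output_path)
    return FileResponse(
        path=output_path,
        media_type=f"audio/{output_format}",
        filename=f"trimmed_{file.filename}"
    )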
# --- AI Endpoints ---
@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
model_id: str = Query("speechbrain_sepformer", description="ID of the enhancement model to use (if multiple loaded)."),
output_format: str = Form("wav", description="Output format for the enhanced audio (wav, flac recommended).")
):
"""Enhances speech audio using a pre-loaded AI model (experimental)."""
if torch is None or sf is None or np is None:
raise HTTPException(status_code=501, detail="AI processing libraries not available.")
if model_id not in enhancement_pipelines:
raise HTTPException(status_code=503, detail=f"Enhancement model '{model_id}' is not loaded or available.")
logger.info(f"Enhance request: file='{file.filename}', model='{model_id}', format='{output_format}'")
input_path = await save_upload_file(file, prefix="enhance_in_")
background_tasks.add_task(cleanup_file, input_path)
output_path = None # Define output_path before try block
try:
# Load audio, ensure correct SR for the model
audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
# Run inference in a separate thread
logger.info("Submitting enhancement task to background thread...")
model_pipeline = enhancement_pipelines[model_id] # Get the specific loaded pipeline/model
enhanced_audio = await asyncio.to_thread(
_run_enhancement_sync, model_pipeline, audio_data, current_sr
)
logger.info("Enhancement task completed.")
# Save the result
output_path = save_hf_audio(enhanced_audio, current_sr, output_format) # Use current_sr (which is target_sr)
background_tasks.add_task(cleanup_file, output_path)
return FileResponse(
path=output_path,
media_type=f"audio/{output_format}",
filename=f"enhanced_{file.filename}"
)
except Exception as e:
logger.error(f"Error during enhancement operation: {e}", exc_info=True)
if output_path: cleanup_file(output_path) # Cleanup if error occurred after output started saving
if isinstance(e, HTTPException): raise e
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="Music audio file to separate into stems."),
model_id: str = Query("demucs", description="ID of the separation model to use."),
stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
"""Separates music into stems (vocals, drums, bass, other) using Demucs (experimental). Returns a ZIP archive."""
if torch is None or sf is None or np is None:
raise HTTPException(status_code=501, detail="AI processing libraries not available.")
if model_id not in separation_models:
raise HTTPException(status_code=503, detail=f"Separation model '{model_id}' is not loaded or available.")
valid_stems = {'vocals', 'drums', 'bass', 'other'}
requested_stems = set(s.lower() for s in stems)
if not requested_stems.issubset(valid_stems):
raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are: {', '.join(valid_stems)}")
logger.info(f"Separate request: file='{file.filename}', model='{model_id}', stems={requested_stems}, format='{output_format}'")
input_path = await save_upload_file(file, prefix="separate_in_")
background_tasks.add_task(cleanup_file, input_path)
stem_output_paths: Dict[str, str] = {}
zip_buffer = None
try:
# Load audio, ensure correct SR for the model (Demucs uses 44.1kHz)
audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
# Run inference in a separate thread
logger.info("Submitting separation task to background thread...")
model = separation_models[model_id] # Get the specific loaded model
all_separated_stems = await asyncio.to_thread(
_run_separation_sync, model, audio_data, current_sr
)
logger.info("Separation task completed.")
# --- Create ZIP file in memory ---
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
# Save only the requested stems
for stem_name in requested_stems:
if stem_name in all_separated_stems:
stem_data = all_separated_stems[stem_name]
# Save stem temporarily to disk first (needed for pydub/sf.write)
stem_path = save_hf_audio(stem_data, current_sr, output_format)
stem_output_paths[stem_name] = stem_path # Keep track for cleanup
background_tasks.add_task(cleanup_file, stem_path) # Schedule cleanup
# Add the saved stem file to the ZIP archive
archive_name = f"{stem_name}_{os.path.basename(input_path)}.{output_format}"
zipf.write(stem_path, arcname=archive_name)
logger.info(f"Added '{archive_name}' to ZIP.")
else:
logger.warning(f"Requested stem '{stem_name}' not found in model output.")
zip_buffer.seek(0) # Rewind buffer pointer
# Return the ZIP file
zip_filename = f"separated_stems_{os.path.splitext(file.filename)[0]}.zip"
return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
)
except Exception as e:
logger.error(f"Error during separation operation: {e}", exc_info=True)
# Cleanup any stems that were saved before zipping failed
for path in stem_output_paths.values():
cleanup_file(path)
if zip_buffer: zip_buffer.close() # Close memory buffer
if isinstance(e, HTTPException): raise e
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
# --- How to Run ---
# 1. Ensure FFmpeg is installed and accessible.
# 2. Save this code as `app.py`.
# 3. Create `requirements.txt` (see the example below).
# 4. Install dependencies: `pip install -r requirements.txt` (This can take time!)
# 5. Run the FastAPI server: `uvicorn app:app --reload --host 0.0.0.0`
# (Use --host 0.0.0.0 for external/Docker access. --reload is optional)
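#
# Example `requirements.txt` (an assumption based on the imports above; pin versions
# as appropriate for your environment):
#   fastapi
#   uvicorn[standard]
#   python-multipart   # required by FastAPI for Form/File parameters
#   pydub
#   torch
#   transformers
#   soundfile
#   librosa
#   numpy
#   # Optional, for the hedged sketches above:
#   # demucs
#   # speechbrain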
#
# --- WARNING ---
# - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
# - The first run will download models, which can take a long time and lots of disk space.
# - Ensure the specific model IDs used are correct and compatible with HF libraries.
# - Model loading at startup might fail if dependencies are missing or resources are insufficient. Check logs!
#
# --- Example Usage (using curl) ---
#
# **Enhance:** (Enhance noisy_speech.wav)
# curl -X POST "http://127.0.0.1:8000/enhance?model_id=speechbrain_sepformer" \
# -F "file=@noisy_speech.wav" \
# -F "output_format=wav" \
# --output enhanced_speech.wav
#
# **Separate:** (Separate vocals and drums from music.mp3)
# curl -X POST "http://127.0.0.1:8000/separate?model_id=demucs" \
# -F "[email protected]" \
# -F "stems=vocals" \
# -F "stems=drums" \
# -F "output_format=mp3" \
# --output separated_stems.zip