import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
import io
import zipfile
# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
# --- AI & Advanced Audio Imports ---
try:
    import torch
    # Transformers is only needed for HF pipelines; speechbrain/demucs are loaded manually.
    # from transformers import pipeline
    import soundfile as sf
    import numpy as np
    import librosa
    # Specific model libraries
    import speechbrain.pretrained
    import demucs.apply
    import demucs.pretrained
    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Error importing AI/Audio libraries: {e}")
    print("Ensure torch, soundfile, librosa, speechbrain, demucs are installed.")
    print("AI features will be unavailable.")
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None
# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Global Variables for Loaded Models ---
# Use consistent keys for storing/retrieving models
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
# Choose a default Demucs model (htdemucs is good quality)
SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}
# Target sampling rates (confirm from model specifics if necessary)
ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
DEMUCS_SR = 44100 # Demucs default is 44.1kHz
# --- Device Selection ---
if torch:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {DEVICE}")
else:
    DEVICE = "cpu"  # Fallback if torch failed to import
# --- Helper Functions (cleanup_file, save_upload_file - same as before) ---
def cleanup_file(file_path: str):
    try:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    _, file_extension = os.path.splitext(upload_file.filename)
    if not file_extension:
        file_extension = ".wav"
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        await upload_file.close()
# --- Audio Loading/Saving for AI Models ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[torch.Tensor, int]":
    """Loads audio, converts to a mono float32 torch tensor, optionally resamples."""
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
        # soundfile returns multi-channel audio as (frames, channels), so average over axis 1.
        if audio.ndim > 1:
            logger.warning(f"Detected {audio.shape[1]} channels, converting to mono by averaging.")
            audio = np.mean(audio, axis=1)
        # Convert the numpy array to a torch tensor
        audio_tensor = torch.from_numpy(audio).float()
        # Resample if necessary using librosa, then convert back to a tensor
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("librosa is required for resampling but is not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            # librosa works on numpy arrays, so convert back temporarily.
            audio_np = audio_tensor.numpy()
            resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
            audio_tensor = torch.from_numpy(resampled_audio_np).float()
            current_sr = target_sr
            logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
        else:
            current_sr = orig_sr
        # Move the tensor to the selected device
        return audio_tensor.to(DEVICE), current_sr
    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path}: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it is a valid audio format.")
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves audio data (torch tensor or numpy array) to a temporary file."""
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI-processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        # Convert tensor to numpy array if needed
        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to numpy array.")
            # Move the tensor to CPU before converting to numpy
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
        # Ensure the data is float32
        if audio_np.dtype != np.float32:
            logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
            audio_np = audio_np.astype(np.float32)
        # Clip values to [-1, 1] to avoid issues with formats expecting that range
        audio_np = np.clip(audio_np, -1.0, 1.0)
        # Use soundfile for lossless formats (wav/flac)
        if output_format.lower() in ('wav', 'flac'):
            sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
        else:
            # For lossy formats, use pydub
            logger.debug(f"Using pydub to export to lossy format: {output_format}")
            # Scale float32 [-1, 1] to int16 for pydub
            audio_int16 = (audio_np * 32767).astype(np.int16)
            # Basic channel check; multi-channel output needs better handling
            num_channels = 1 if audio_int16.ndim == 1 else audio_int16.shape[0]
            if num_channels > 1:
                audio_int16 = audio_int16[0]  # Use the first channel if more than one
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1,  # Forcing mono currently
            )
            segment.export(output_path, format=output_format)
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI-processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")
# --- Synchronous AI Inference Functions ---
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
    """Synchronous wrapper for SpeechBrain enhancement model inference."""
    if not model:
        raise ValueError("Enhancement model not loaded")
    try:
        logger.info(f"Running speech enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, device: {audio_tensor.device})...")
        # SpeechBrain models take tensors directly and expect a batch dimension
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        # Move the tensor to the same device as the model
        model_device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(model_device)
        with torch.no_grad():
            # Sepformer enhancement checkpoints are exposed through SpeechBrain's
            # separation interface: separate_batch returns (batch, time, num_sources),
            # with a single "source" (the enhanced signal) for enhancement models.
            est_sources = model.separate_batch(audio_tensor)
            enhanced_tensor = est_sources[:, :, 0]
        # Remove the batch dimension and move back to CPU before returning
        enhanced_audio = enhanced_tensor.squeeze(0).cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
    """Synchronous wrapper for Demucs source separation model inference."""
    if not model:
        raise ValueError("Separation model not loaded")
    if not demucs:
        raise RuntimeError("Demucs library not available")
    try:
        logger.info(f"Running source separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, device: {audio_tensor.device})...")
        # Demucs expects audio as (batch, channels, samples)
        model_device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(model_device)
        # Add batch and channel dimensions if the input is mono
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, N)
        elif audio_tensor.ndim == 2:  # Rare if load_audio_for_hf returns mono tensors
            logger.warning("Input tensor has 2 dims, assuming (batch, samples); adding channel dim.")
            audio_tensor = audio_tensor.unsqueeze(1)  # (B, 1, N)
        # Match the number of channels the model expects (usually 2)
        if audio_tensor.shape[1] != model.audio_channels:
            logger.warning(f"Model expects {model.audio_channels} channels, input has {audio_tensor.shape[1]}. Repeating mono channel.")
            audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
        with torch.no_grad():
            # demucs.apply.apply_model handles chunked inference and supports bagged
            # models like htdemucs. Input: (batch, channels, samples);
            # output: (batch, stems, channels, samples).
            sources = demucs.apply.apply_model(
                model, audio_tensor, device=model_device, shifts=1, split=True, overlap=0.25
            )[0]  # Drop the batch dimension -> (stems, channels, samples)
        logger.debug(f"Raw separated sources tensor shape: {sources.shape}")
        # Map stems based on the model's sources list
        # (htdemucs default order: drums, bass, other, vocals)
        stem_map = {name: sources[i] for i, name in enumerate(model.sources)}
        # Convert each stem back to mono (average channels), detach, and move to CPU
        output_stems = {}
        for name, data in stem_map.items():
            output_stems[name] = data.mean(dim=0).detach().cpu()
        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
# --- Model Loading Function ---
def load_hf_models():
    """Loads AI models at startup using the correct libraries."""
    global enhancement_models, separation_models
    if torch is None or speechbrain is None or demucs is None:
        logger.error("Core AI libraries (torch, speechbrain, demucs) not available. Skipping model loading.")
        return
    # --- Load Enhancement Model (SpeechBrain) ---
    enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
    try:
        logger.info(f"Loading enhancement model: {enhancement_model_hparams} (using SpeechBrain)...")
        # Optionally point SpeechBrain at a writable download location:
        # savedir = os.path.join(TEMP_DIR, "speechbrain_models")
        # os.makedirs(savedir, exist_ok=True)
        # Sepformer enhancement checkpoints are loaded via the SepformerSeparation interface.
        enhancer = speechbrain.pretrained.SepformerSeparation.from_hparams(
            source=enhancement_model_hparams,
            # savedir=savedir,  # Specify the download dir if needed
            run_opts={"device": DEVICE},  # Pass the device option
        )
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer  # Store under a consistent key
        logger.info(f"Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {DEVICE}.")
    except Exception as e:
        logger.error(f"Failed to load enhancement model '{enhancement_model_hparams}': {e}", exc_info=True)
    # --- Load Separation Model (Demucs) ---
    # Using a standard pretrained model name from the demucs package
    separation_model_name = SEPARATION_MODEL_KEY  # e.g., "htdemucs" or "mdx_extra_q"
    try:
        logger.info(f"Loading separation model: {separation_model_name} (using the demucs package)...")
        # demucs.pretrained.get_model downloads the checkpoint automatically
        separator = demucs.pretrained.get_model(name=separation_model_name)
        separator.to(DEVICE)
        separator.eval()
        separation_models[SEPARATION_MODEL_KEY] = separator  # Store under a consistent key
        logger.info(f"Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {DEVICE}.")
    except Exception as e:
        logger.error(f"Failed to load separation model '{separation_model_name}': {e}", exc_info=True)
        logger.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
# --- FastAPI App ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries (torch, speechbrain, demucs).",
    version="2.1.0",
)
@app.on_event("startup")
async def startup_event():
    logger.info("Application startup: loading AI models...")
    # Run the blocking model load in a worker thread
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished (check logs for success/failure).")
# --- API Endpoints ---
@app.get("/", tags=["General"])
def read_root():
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    if enhancement_models:
        ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
    if separation_models:
        ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY})")
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None available (check startup logs)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup.",
    }
# --- Basic Editing Endpoints ---
# (Add the /trim, /concat, /volume, and /convert endpoints here - same logic as before.)
# Make sure they use the updated cleanup_file and save_upload_file helpers.
# A sketch of one such endpoint follows.
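# Illustrative sketch of one basic editing endpoint; the originals are elided
# above, so the route shape, parameter names, and defaults here are assumptions,
# not the original implementation.
@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., description="Start time in milliseconds."),
    end_ms: int = Form(..., description="End time in milliseconds."),
    output_format: str = Form("wav", description="Output format."),
):
    """Trims audio to [start_ms, end_ms] using pydub (requires FFmpeg)."""
    if not 0 <= start_ms < end_ms:
        raise HTTPException(status_code=422, detail="Require 0 <= start_ms < end_ms.")
    input_path = await save_upload_file(file, prefix="trim_in_")
    background_tasks.add_task(cleanup_file, input_path)
    try:
        segment = AudioSegment.from_file(input_path)
    except CouldntDecodeError:
        raise HTTPException(status_code=415, detail=f"Could not decode audio file: {file.filename}")
    output_path = os.path.join(TEMP_DIR, f"trim_out_{uuid.uuid4().hex}.{output_format}")
    segment[start_ms:end_ms].export(output_path, format=output_format)
    background_tasks.add_task(cleanup_file, output_path)
    return FileResponse(
        path=output_path,
        media_type=f"audio/{output_format}",
        filename=f"trimmed_{os.path.splitext(file.filename)[0]}.{output_format}",
    )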
# --- AI Endpoints (Corrected) ---
@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
# Model ID is less relevant now if only one is loaded, but keep for future flexibility
model_key: str = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use."),
output_format: str = Form("wav", description="Output format (wav, flac recommended).")
):
"""Enhances speech audio using a pre-loaded SpeechBrain model."""
if torch is None or speechbrain is None:
raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
if model_key not in enhancement_models:
logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
loaded_model = enhancement_models[model_key]
logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
input_path = await save_upload_file(file, prefix="enhance_in_")
background_tasks.add_task(cleanup_file, input_path)
output_path = None
try:
# Load audio as tensor, ensure correct SR
# SpeechBrain Sepformer expects 16kHz
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
logger.info("Submitting enhancement task to background thread...")
enhanced_audio_tensor = await asyncio.to_thread(
_run_enhancement_sync, loaded_model, audio_tensor, current_sr # Pass SR even if unused by func now
)
logger.info("Enhancement task completed.")
# Save the result (tensor output from enhancer)
output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format) # Save at model's SR
background_tasks.add_task(cleanup_file, output_path)
return FileResponse(
path=output_path,
media_type=f"audio/{output_format}",
filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
)
except Exception as e:
logger.error(f"Error during enhancement operation: {e}", exc_info=True)
if output_path: cleanup_file(output_path)
if isinstance(e, HTTPException): raise e
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="Music audio file to separate into stems."),
model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
"""Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
if torch is None or demucs is None:
raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
if model_key not in separation_models:
logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
loaded_model = separation_models[model_key]
valid_stems = set(loaded_model.sources) # Get stems directly from loaded model
requested_stems = set(s.lower() for s in stems)
if not requested_stems.issubset(valid_stems):
raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Model '{model_key}' provides: {', '.join(valid_stems)}")
logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
input_path = await save_upload_file(file, prefix="separate_in_")
background_tasks.add_task(cleanup_file, input_path)
stem_output_paths: Dict[str, str] = {}
zip_buffer = None
try:
# Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
logger.info("Submitting separation task to background thread...")
all_separated_stems_tensors = await asyncio.to_thread(
_run_separation_sync, loaded_model, audio_tensor, current_sr # Pass SR even if unused by func now
)
logger.info("Separation task completed.")
# --- Create ZIP file in memory ---
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
# Save only the requested stems
for stem_name in requested_stems:
if stem_name in all_separated_stems_tensors:
stem_tensor = all_separated_stems_tensors[stem_name]
# Save stem temporarily (save_hf_audio handles tensor)
# Use the model's native sampling rate for output
stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
stem_output_paths[stem_name] = stem_path
background_tasks.add_task(cleanup_file, stem_path)
archive_name = f"{stem_name}_{os.path.splitext(file.filename)[0]}.{output_format}"
zipf.write(stem_path, arcname=archive_name)
logger.info(f"Added '{archive_name}' to ZIP.")
else:
# This case should be prevented by the earlier validation
logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen).")
zip_buffer.seek(0)
zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
)
except Exception as e:
logger.error(f"Error during separation operation: {e}", exc_info=True)
for path in stem_output_paths.values(): cleanup_file(path)
if zip_buffer: zip_buffer.close()
if isinstance(e, HTTPException): raise e
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
# --- Add back the basic editing endpoints (/trim, /concat, /volume, /convert) here ---
# ... (Remember to include them) ...
# --- How to Run ---
# 1. Ensure FFmpeg is installed.
# 2. Save code as `app.py`. Create/update `requirements.txt`.
# 3. Install: `pip install -r requirements.txt` (May take significant time/space!)
# 4. Run: `uvicorn app:app --reload --host 0.0.0.0`
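#
# Example `requirements.txt` covering the imports above (a sketch; pin the
# versions you actually test against):
#   fastapi
#   uvicorn[standard]
#   python-multipart
#   pydub
#   torch
#   soundfile
#   numpy
#   librosa
#   speechbrain
#   demucs
#
# Example request once the server is running (local URL and filename assumed):
#   curl -X POST "http://localhost:8000/separate" \
#     -F "file=@song.mp3" -F "stems=vocals" -F "stems=drums" \
#     -F "output_format=wav" -o stems.zip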