# Aiaudio / app.py — header captured from the Hugging Face Space file view:
# author Athspi, commit 2c84da8 "Update app.py" (verified), 45.5 kB.
# app.py
import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any
import traceback # For detailed error logging
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
import io
import zipfile
# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
# --- AI & Advanced Audio Imports ---
# Add extra logging around imports
# Shared formatter and console handler. The same handler object `ch` is later
# reused by the "ModelLoader" logger.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

# Dedicated logger for import/startup progress messages.
logger_init = logging.getLogger("AppInit")
logger_init.setLevel(logging.INFO)
# Attach the handler only once, so a module re-import (e.g. an auto-reloading
# server) does not print every message twice.
if len(logger_init.handlers) == 0:
    logger_init.addHandler(ch)
# Import the heavy AI / audio stack. Any failure disables AI features instead of
# taking the whole API down; AI_LIBS_AVAILABLE records the outcome.
AI_LIBS_AVAILABLE = False
try:
    logger_init.info("Importing torch...")
    import torch
    logger_init.info("Importing soundfile...")
    import soundfile as sf
    logger_init.info("Importing numpy...")
    import numpy as np
    logger_init.info("Importing librosa...")
    import librosa
    logger_init.info("Importing speechbrain...")
    import speechbrain.pretrained
    logger_init.info("Importing demucs...")
    import demucs.separate
    import demucs.apply
    logger_init.info("AI and advanced audio libraries imported successfully.")
    AI_LIBS_AVAILABLE = True
except Exception as e:
    # Deliberately broader than ImportError. Native dependencies fail in other
    # ways: soundfile raises OSError when libsndfile is missing, and torch can
    # raise OSError on broken shared libraries. This is a top-level boundary,
    # and the failure is logged with its traceback.
    logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
    logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
    logger_init.error("AI features will be unavailable.")
    # Placeholders so module-level references below still resolve; every AI
    # code path checks AI_LIBS_AVAILABLE before touching them.
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None
# --- Configuration & Setup ---
def _resolve_temp_dir() -> str:
    """Return the directory used for uploads and rendered outputs.

    Prefer the system temp directory (creating it if needed). If that fails,
    fall back to the current working directory and log why.
    """
    candidate = tempfile.gettempdir()
    try:
        os.makedirs(candidate, exist_ok=True)
    except OSError as mkdir_err:
        logger_init.error(f"Could not create temporary directory {candidate}: {mkdir_err}")
        fallback = "."  # current directory: works, but less ideal
        logger_init.warning(f"Using current directory '{fallback}' for temporary files.")
        return fallback
    return candidate


TEMP_DIR = _resolve_temp_dir()
# Logger for the endpoint handlers. Its handlers/levels come from whatever
# logging configuration the hosting server installs.
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Registry keys, plus the dicts that load_hf_models() fills at startup.
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
SEPARATION_MODEL_KEY = "htdemucs"  # "mdx_extra_q" is a faster, quantized alternative
enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

# Sample rates audio is prepared at for each model (re-check if models change).
ENHANCEMENT_SR = 16000  # Sepformer WHAMR operates at 16 kHz
DEMUCS_SR = 44100  # Demucs default is 44.1 kHz

# --- Device Selection ---
if not (AI_LIBS_AVAILABLE and torch):
    # Without torch only the CPU placeholder makes sense.
    DEVICE = "cpu"
    logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")
else:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    logger_init.info(f"Selected device for AI models: {DEVICE}")
# --- Helper Functions ---
def cleanup_file(file_path: Optional[str]):
    """Best-effort removal of a temporary file.

    Falsy or non-string paths and files that are already gone are ignored
    silently. Any other failure is logged and swallowed, so one bad path never
    aborts the cleanup of the others.
    """
    if not file_path or not isinstance(file_path, str):
        return
    try:
        # EAFP instead of exists()+remove(): avoids a race in which the file
        # disappears between the check and the removal and a spurious error
        # gets logged.
        os.remove(file_path)
    except FileNotFoundError:
        pass  # already cleaned up (e.g. by an earlier immediate cleanup)
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Stream an upload into TEMP_DIR and return the path of the saved copy.

    The client's file extension is kept (``.wav`` when it has none), so later
    extension-based format hints still work. The upload is closed in every
    case.

    Raises:
        HTTPException: 400 when the upload has no filename; 500 when writing
        fails (any partially written file is removed first).
    """
    if not (upload_file and upload_file.filename):
        raise HTTPException(status_code=400, detail="Invalid file upload object.")

    extension = os.path.splitext(upload_file.filename)[1] or ".wav"
    destination = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{extension}")
    chunk_size = 1024 * 1024  # 1 MiB pieces keep large uploads out of memory

    try:
        logger.debug(f"Attempting to save uploaded file to: {destination}")
        with open(destination, "wb") as sink:
            while True:
                chunk = await upload_file.read(chunk_size)
                if not chunk:
                    break
                sink.write(chunk)
        logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {destination}")
        return destination
    except Exception as exc:
        logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {destination}: {exc}", exc_info=True)
        cleanup_file(destination)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        try:
            await upload_file.close()
        except Exception:
            pass  # closing is best-effort; the result above is what matters
# --- Audio Loading/Saving Functions ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[torch.Tensor, int]":
    """Load an audio file as a mono float32 torch tensor for the AI models.

    The return annotation is a string on purpose. When the AI imports fail,
    ``torch`` is the ``None`` placeholder, and evaluating ``torch.Tensor`` at
    definition time would make the whole module fail to import.

    Args:
        file_path: Path of a previously saved upload.
        target_sr: When given and different from the file's native rate, the
            audio is resampled to this rate with librosa.

    Returns:
        ``(audio_tensor, sample_rate)``. ``audio_tensor`` is 1-D and float32,
        already moved to ``DEVICE``.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable; 500 if the file
        is missing or processing fails; 415 if soundfile cannot decode it. On
        the 415 and processing-500 paths the input file is deleted.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")

        # With always_2d=False, soundfile returns a 1-D array for mono files
        # and a (frames, channels) array otherwise. The channel axis is
        # therefore always axis 1; no shape guessing is needed (guessing broke
        # on very short multi-channel clips).
        if audio.ndim > 1:
            logger.info(f"Detected {audio.shape[1]} channels. Converting to mono by averaging axis 1.")
            audio = audio.mean(axis=1)
        audio = audio.reshape(-1)

        current_sr = orig_sr
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("Librosa missing for resampling")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
            # Resample in NumPy *before* creating the tensor, which avoids a
            # tensor -> NumPy -> tensor round trip.
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best')
            current_sr = target_sr
            logger.info(f"Resampled audio shape: {audio.shape}")

        audio_tensor = torch.from_numpy(np.ascontiguousarray(audio)).float()
        return audio_tensor.to(DEVICE), current_sr
    except sf.SoundFileError as sf_err:
        logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}")
    except Exception as e:
        logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.")
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Write model output (torch Tensor or NumPy array) to a new temp file.

    Processing steps:
    - convert the samples to float32;
    - replace non-finite samples;
    - clip to [-1, 1];
    - reduce multi-dimensional data to mono.

    WAV/FLAC are written with soundfile. Any other format is encoded by pydub
    from 16-bit PCM.

    Returns:
        Path of the written file inside TEMP_DIR.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable; 500 if
        conversion or writing fails (the partial output is removed).
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to NumPy array.")
            # .numpy() requires a CPU tensor without autograd history.
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")

        if audio_np.dtype != np.float32:
            logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
            audio_np = audio_np.astype(np.float32)

        # np.clip leaves NaN untouched, and casting NaN/inf to int16 (lossy
        # path below) yields garbage samples. Map NaN to silence and +/-inf to
        # full scale before clipping.
        if not np.all(np.isfinite(audio_np)):
            logger.warning("Output audio contains NaN/inf samples; replacing them before saving.")
            audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=1.0, neginf=-1.0)
        audio_np = np.clip(audio_np, -1.0, 1.0)

        # Both writers below are used in mono mode, so reduce to 1-D.
        if audio_np.ndim > 1:
            logger.warning(f"Output audio data has {audio_np.ndim} dimensions, attempting to flatten or take first dimension.")
            # Heuristic: the smallest axis is the channel axis. Average it when
            # its size looks like a real channel count (2-9); otherwise flatten.
            channel_dim = int(np.argmin(audio_np.shape))
            if 1 < audio_np.shape[channel_dim] < 10:
                audio_np = np.mean(audio_np, axis=channel_dim)
            else:
                audio_np = audio_np.flatten()

        if output_format.lower() in ['wav', 'flac']:
            sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
        else:
            logger.debug(f"Using pydub to export to lossy format: {output_format}")
            # Scale float32 [-1, 1] to int16 PCM for pydub.
            audio_int16 = (audio_np * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1,  # reduced to mono above
            )
            segment.export(output_path, format=output_format)
        logger.info(f"Successfully saved AI audio to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.")
# --- Pydub Loading/Exporting (for basic edits) ---
def load_audio_pydub(file_path: str) -> AudioSegment:
    """Decode ``file_path`` into a pydub AudioSegment.

    The file extension is first passed to pydub/FFmpeg as a format hint. The
    extension comes from the client's filename (or defaults to ``.wav`` when
    the upload had none), so it can be wrong. If decoding with the hint fails,
    the load is retried once with FFmpeg's own format auto-detection.

    Raises:
        HTTPException: 500 if the file is missing or an unexpected error
        occurs; 415 if the data cannot be decoded at all. On the 415/500
        decode paths the input file is deleted.
    """
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")
    file_ext = os.path.splitext(file_path)[1][1:].lower()
    try:
        logger.debug(f"Loading audio with pydub: {file_path}")
        try:
            audio = AudioSegment.from_file(file_path, format=file_ext or None)
        except CouldntDecodeError:
            if not file_ext:
                raise  # auto-detection already failed; nothing left to try
            logger.info(f"Decoding '{file_path}' as '{file_ext}' failed; retrying with format auto-detection.")
            audio = AudioSegment.from_file(file_path)
        logger.info(f"Loaded audio using pydub from: {file_path}")
        return audio
    except CouldntDecodeError as e:
        logger.warning(f"Pydub CouldntDecodeError for {file_path}: {e}")
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
    except Exception as e:
        logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")
def export_audio_pydub(audio: AudioSegment, format: str) -> str:
    """Export ``audio`` to a new temp file with extension ``format``; return its path.

    ``format`` is a file extension such as 'mp3' or 'm4a'. pydub forwards its
    ``format`` argument to ``ffmpeg -f``, which expects a *muxer* name. Some
    common extensions differ from their muxer name, so they are translated
    (m4a -> ipod, aac -> adts, wma -> asf). The file keeps the requested
    extension.

    Raises:
        HTTPException: 500 if the export fails (any partial file is removed).
    """
    extension = format.lower()
    muxer = {"m4a": "ipod", "aac": "adts", "wma": "asf"}.get(extension, extension)
    output_filename = f"edited_{uuid.uuid4().hex}.{extension}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
        audio.export(output_path, format=muxer)
        return output_path
    except Exception as e:
        logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
# --- Synchronous AI Inference Functions ---
def _run_enhancement_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "torch.Tensor":
    """Run a SpeechBrain enhancement model on one waveform (blocking).

    Intended to be called via ``asyncio.to_thread``. The annotations are
    strings, so this module still imports when torch is unavailable.

    Args:
        model: Loaded SpeechBrain interface. Its ``enhance_batch`` is used if
            present, otherwise ``forward``.
        audio_tensor: 1-D waveform, or an already batched 2-D tensor.
        sampling_rate: Used for logging only.

    Returns:
        Enhanced audio with the leading batch axis squeezed, on the CPU.

    Raises:
        ValueError: if the model/libraries are unavailable or the model has no
        usable inference method. Inference errors are logged and re-raised.
    """
    import inspect  # local import: only needed for the signature check below

    if not AI_LIBS_AVAILABLE or not model:
        raise ValueError("Enhancement model/libs not available")
    try:
        logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)  # add batch axis

        enhance_method = getattr(model, "enhance_batch", None) or getattr(model, "forward", None)
        if enhance_method is None:
            raise ValueError(f"Model of type {type(model).__name__} has neither enhance_batch nor forward")
        # Inspect the declared parameters. __code__.co_varnames also lists
        # local variables, so it can report "lengths" falsely.
        try:
            accepts_lengths = "lengths" in inspect.signature(enhance_method).parameters
        except (TypeError, ValueError):
            accepts_lengths = False

        with torch.no_grad():
            if accepts_lengths:
                # SpeechBrain lengths are relative to the longest item in the
                # batch (1.0 = full length), not sample counts.
                rel_lengths = torch.ones(audio_tensor.shape[0], device=model_device)
                enhanced_tensor = enhance_method(audio_tensor, lengths=rel_lengths)
            else:
                enhanced_tensor = enhance_method(audio_tensor)

        enhanced_audio = enhanced_tensor.squeeze(0).cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise  # surfaced to the async caller
def _run_separation_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "Dict[str, torch.Tensor]":
    """Separate one waveform into stems with a Demucs model (blocking).

    Intended to be called via ``asyncio.to_thread``. The annotations are
    strings, so this module still imports when torch is unavailable.

    Args:
        model: Loaded Demucs model. It must expose ``audio_channels``,
            ``sources`` and ``parameters()``.
        audio_tensor: 1-D waveform, 2-D ``(batch, samples)`` or 3-D
            ``(batch, channels, samples)``. Only the first batch item is
            separated.
        sampling_rate: Used for logging only. The caller must already provide
            audio at the model's rate.

    Returns:
        Mapping of stem name to a mono (channel-averaged) 1-D CPU tensor.

    Raises:
        ValueError/RuntimeError: if the libraries or model are unavailable or
        the channel count is incompatible. Inference errors are logged and
        re-raised.
    """
    if not AI_LIBS_AVAILABLE or not model:
        raise ValueError("Separation model/libs not available")
    if not demucs:
        raise RuntimeError("Demucs library missing")
    try:
        logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)

        # Normalise to (batch, channels, samples).
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, N)
        elif audio_tensor.ndim == 2:
            audio_tensor = audio_tensor.unsqueeze(1)  # (B, 1, N)

        # Duplicate a mono channel when the model expects more channels.
        if audio_tensor.shape[1] != model.audio_channels:
            if audio_tensor.shape[1] == 1:
                logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
                audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
            else:
                raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")

        # Keep exactly one batch item but *keep* the batch axis. apply_model
        # unpacks mix.shape as (batch, channels, length), so squeezing the
        # batch axis away fails.
        audio_tensor = audio_tensor[:1]
        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")

        with torch.no_grad():
            # apply_model handles chunking/overlap. The output is
            # (batch, stems, channels, samples); take the single batch item.
            out = demucs.apply.apply_model(
                model, audio_tensor, device=model_device,
                shifts=1, split=True, overlap=0.25, progress=False,
            )[0]
        logger.debug(f"Raw separated sources tensor shape: {out.shape}")

        # Stems follow the order of model.sources. Average channels to mono.
        output_stems = {
            name: out[i].mean(dim=0).detach().cpu()
            for i, name in enumerate(model.sources)
        }
        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
# --- Model Loading Function (Enhanced Logging) ---
def load_hf_models():
    """Load the enhancement and separation models into the module registries.

    Fills ``enhancement_models[ENHANCEMENT_MODEL_KEY]`` and
    ``separation_models[SEPARATION_MODEL_KEY]`` on ``DEVICE``. The two models
    load independently. A failure is logged with its traceback and leaves that
    registry empty, so the status endpoint and AI endpoints report the model
    as unavailable. This is blocking (downloads plus weight loading); the
    startup hook runs it in a worker thread.
    """
    global enhancement_models, separation_models
    logger_load = logging.getLogger("ModelLoader")
    logger_load.setLevel(logging.INFO)
    # Reuse the shared console handler, attaching it only once.
    if not logger_load.handlers and ch:
        logger_load.addHandler(ch)

    if not AI_LIBS_AVAILABLE:
        logger_load.error("Core AI libraries not available. Cannot load AI models.")
        return

    load_success_flags = {"enhancement": False, "separation": False}

    # --- Enhancement model (SpeechBrain) ---
    enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
    logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
        # This checkpoint is published for SpeechBrain's SepformerSeparation
        # interface. Use SepformerEnhancement (the name originally used here)
        # only if the installed SpeechBrain actually provides it.
        enhancer_cls = getattr(speechbrain.pretrained, "SepformerEnhancement", None)
        if enhancer_cls is None:
            enhancer_cls = speechbrain.pretrained.SepformerSeparation
        enhancer = enhancer_cls.from_hparams(
            source=enhancement_model_hparams,
            run_opts={"device": DEVICE},
        )
        model_device = next(enhancer.parameters()).device
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
        logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
        load_success_flags["enhancement"] = True
    except Exception:
        logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=True)
        logger_load.warning("Enhancement features will be unavailable.")

    # --- Separation model (Demucs) ---
    separation_model_name = SEPARATION_MODEL_KEY  # e.g. "htdemucs"
    logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
        # Name-based loading (including the checkpoint download) lives in
        # demucs.pretrained.get_model; demucs.apply offers no such loader.
        from demucs.pretrained import get_model
        separator = get_model(name=separation_model_name)
        separator.to(DEVICE)
        separator.eval()  # inference mode
        model_device = next(separator.parameters()).device
        separation_models[SEPARATION_MODEL_KEY] = separator
        logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
        logger_load.info(f"Separation model available sources: {separator.sources}")
        load_success_flags["separation"] = True
    except Exception:
        logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=True)
        logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
        logger_load.warning("Separation features will be unavailable.")

    logger_load.info("--- Model loading attempts finished ---")
    logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
    logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")
# --- FastAPI App ---
# The single application object; all routes below are registered on it.
app = FastAPI(
    version="2.1.2",  # bumped with each release of this service
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
)
@app.on_event("startup")
async def startup_event():
    """Load the AI models once at startup, off the event loop.

    Startup messages go to the "AppInit" logger. Model-specific details are
    logged by load_hf_models under "ModelLoader".
    """
    logger_init.info("--- FastAPI Application Startup ---")
    if not AI_LIBS_AVAILABLE:
        logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
    else:
        logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
        # load_hf_models blocks (downloads/weights), so run it in a worker thread.
        await asyncio.to_thread(load_hf_models)
        logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
    logger_init.info("--- Startup sequence complete ---")
# --- API Endpoints ---
@app.get("/", tags=["General"])
def read_root():
    """Welcome message plus a per-model status report for the AI features."""
    if not AI_LIBS_AVAILABLE:
        ai_features_status = {"AI Status": "Libraries Failed Import"}
    else:
        not_loaded = "Failed to load (check startup logs)"
        ai_features_status = {
            ENHANCEMENT_MODEL_KEY: "Loaded" if enhancement_models else not_loaded,
        }
        if not separation_models:
            ai_features_status[SEPARATION_MODEL_KEY] = not_loaded
        else:
            sep_model = separation_models.get(SEPARATION_MODEL_KEY)
            source_list = ', '.join(sep_model.sources) if sep_model else 'N/A'
            ai_features_status[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {source_list})"

    return {
        "message": "Welcome to the AI Audio Editor API.",
        "status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
        "ai_models_status": ai_features_status,
        "basic_endpoints": ["/trim", "/concat", "/volume", "/convert"],
        "notes": "Requires FFmpeg. AI features require successful model loading at startup.",
    }
# --- Basic Editing Endpoints ---
@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
    end_ms: int = Form(..., gt=0, description="End time in milliseconds."),
):
    """Trim an upload to ``[start_ms, end_ms)`` and return it in its original format (Pydub).

    Raises:
        HTTPException: 422 for an empty or out-of-range window; 415/500 from
        loading or exporting.
    """
    if end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")
    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    input_path = await save_upload_file(file, prefix="trim_in_")
    # Background tasks are attached to the *returned* response. When this
    # handler raises, they never run, so the error branches below also delete
    # the temp files immediately.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio = load_audio_pydub(input_path)
        # len() of an AudioSegment is its duration in ms.
        if start_ms >= len(audio):
            raise HTTPException(status_code=422, detail=f"start_ms ({start_ms}) is beyond the end of the audio ({len(audio)}ms).")
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")

        # Re-export in the upload's format; fall back to mp3 when the
        # extension is missing or implausibly long.
        original_format = os.path.splitext(file.filename)[1][1:].lower()
        if not original_format or len(original_format) > 5:
            original_format = "mp3"
            logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")
        output_path = export_audio_pydub(trimmed_audio, original_format)
        background_tasks.add_task(cleanup_file, output_path)

        output_filename = f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",  # best-effort MIME type
            filename=output_filename,
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during trim: {http_exc.detail}")
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}") from e
@app.post("/concat", tags=["Basic Editing"])
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg')."),
):
    """Concatenate two or more uploads in order and return the result (Pydub).

    Files that cannot be decoded are skipped with a log entry. The request
    fails only when none of them could be loaded.

    Raises:
        HTTPException: 422 for fewer than two files or an unsupported format;
        400 if no file could be loaded; 500 on unexpected errors.
    """
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
    # Same whitelist as /convert. The value becomes both a file extension and
    # the FFmpeg output format, so reject anything unexpected up front.
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'}
    output_format = output_format.lower()
    if output_format not in allowed_formats:
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(allowed_formats))}")
    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")

    input_paths: List[str] = []
    output_path = None
    try:
        combined_audio: Optional[AudioSegment] = None
        loaded_names: List[str] = []  # uploads that actually made it into the mix
        for i, file in enumerate(files):
            if not file or not file.filename:
                logger.warning(f"Skipping invalid file upload at index {i}.")
                continue
            input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
            input_paths.append(input_path)
            # Runs only on success; error branches below clean up explicitly.
            background_tasks.add_task(cleanup_file, input_path)
            try:
                audio = load_audio_pydub(input_path)
                if combined_audio is None:
                    combined_audio = audio
                    logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
                else:
                    logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
                    combined_audio += audio
                loaded_names.append(file.filename)
            except HTTPException as load_exc:
                logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
            except Exception as load_exc:
                logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)

        if combined_audio is None:
            raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")
        logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")

        output_path = export_audio_pydub(combined_audio, output_format)
        background_tasks.add_task(cleanup_file, output_path)

        # Name the result after the first file actually used, counting only
        # the files that were joined.
        first_filename_base = os.path.splitext(loaded_names[0])[0]
        output_filename = f"concat_{first_filename_base}_and_{len(loaded_names) - 1}_others.{output_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename,
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during concat: {http_exc.detail}")
        for saved_path in input_paths:
            cleanup_file(saved_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
        for saved_path in input_paths:
            cleanup_file(saved_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}") from e
@app.post("/volume", tags=["Basic Editing"])
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases."),
):
    """Apply a gain of ``change_db`` dB to an upload and return it (Pydub).

    Raises:
        HTTPException: 422 for a non-finite gain; 415/500 from loading or
        exporting.
    """
    import math  # local import: only needed for the gain validation

    # Float form fields also accept "inf"/"nan", which would corrupt the samples.
    if not math.isfinite(change_db):
        raise HTTPException(status_code=422, detail="change_db must be a finite number.")
    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    input_path = await save_upload_file(file, prefix="volume_in_")
    # Runs only after a successful response; error branches clean up explicitly.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio = load_audio_pydub(input_path)
        if audio.dBFS == -float('inf'):
            logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
        adjusted_audio = audio + change_db  # pydub: adding a number applies gain in dB
        logger.info(f"Volume adjusted by {change_db}dB.")

        original_format = os.path.splitext(file.filename)[1][1:].lower()
        if not original_format or len(original_format) > 5:
            original_format = "mp3"
        output_path = export_audio_pydub(adjusted_audio, original_format)
        background_tasks.add_task(cleanup_file, output_path)

        sign = "+" if change_db >= 0 else ""
        output_filename = f"volume_{sign}{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",
            filename=output_filename,
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}") from e
@app.post("/convert", tags=["Basic Editing"])
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac')."),
):
    """Convert an upload to ``output_format`` (Pydub) and return it.

    Raises:
        HTTPException: 422 for a format outside the whitelist; 415/500 from
        loading or exporting.
    """
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'}
    output_format_lower = output_format.lower()
    if output_format_lower not in allowed_formats:
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(list(allowed_formats)))}")

    logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
    input_path = await save_upload_file(file, prefix="convert_in_")
    # Runs only after a successful response; error branches clean up explicitly.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio = load_audio_pydub(input_path)
        logger.info(f"Successfully loaded '{file.filename}' for conversion.")
        output_path = export_audio_pydub(audio, output_format_lower)
        background_tasks.add_task(cleanup_file, output_path)
        logger.info(f"Successfully exported to {output_format_lower}")

        filename_base = os.path.splitext(file.filename)[0]
        output_filename = f"{filename_base}_converted.{output_format_lower}"
        # MIME types per output format; m4a is an MP4 container.
        media_type_map = {
            'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
            'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4',
            'opus': 'audio/opus', 'wma': 'audio/x-ms-wma', 'aiff': 'audio/aiff',
        }
        media_type = media_type_map.get(output_format_lower, 'application/octet-stream')
        return FileResponse(
            path=output_path,
            media_type=media_type,
            filename=output_filename,
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}") from e
# --- AI Endpoints ---
@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    # Keep model_key optional for now, assumes default if only one loaded
    model_key: Optional[str] = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use (defaults to primary)."),
    output_format: str = Form("wav", description="Output format (wav, flac recommended)."),
):
    """Enhance speech audio using a pre-loaded SpeechBrain model.

    The upload is saved to a temp file and loaded at ``ENHANCEMENT_SR``. It is
    then passed to ``_run_enhancement_sync`` in a worker thread via
    ``asyncio.to_thread``, so the event loop is not blocked during inference.
    The enhanced audio is returned as a file download.

    Args:
        background_tasks: Used to delete the temp input/output files once a
            successful response has been sent.
        file: The uploaded (noisy) audio.
        model_key: Key into ``enhancement_models``; a falsy value falls back
            to ``ENHANCEMENT_MODEL_KEY``.
        output_format: Format/extension handed to ``save_hf_audio``. It is
            case-insensitive and normalised to lower case.

    Raises:
        HTTPException: 501 if the AI libraries failed to import, 503 if the
            requested model is not loaded, 500 on any processing failure.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")

    # Use the provided key or the default
    actual_model_key = model_key or ENHANCEMENT_MODEL_KEY
    if actual_model_key not in enhancement_models:
        logger.error(f"Enhancement model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{actual_model_key}' is not loaded or available. Check server startup logs.")
    loaded_model = enhancement_models[actual_model_key]

    # Normalise like /convert does, so "WAV" and "wav" behave the same.
    output_format = output_format.lower()
    logger.info(f"Enhance request: file='{file.filename}', model='{actual_model_key}', format='{output_format}'")

    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        # Load audio as tensor, resampled to the enhancer's rate (ENHANCEMENT_SR)
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Enhancement task completed.")

        # Save the result at the enhancer's sampling rate
        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
        background_tasks.add_task(cleanup_file, output_path)

        # UploadFile.filename is Optional; fall back to a generic stem.
        base_name = os.path.splitext(file.filename or "audio")[0]
        output_filename = f"enhanced_{base_name}.{output_format}"
        # "audio/<ext>" is correct for wav/flac/ogg, but a few common
        # extensions have a different registered MIME type.
        media_type = {"mp3": "audio/mpeg", "m4a": "audio/mp4"}.get(
            output_format, f"audio/{output_format}"
        )
        return FileResponse(path=output_path, media_type=media_type, filename=output_filename)
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during enhancement: {http_exc.detail}")
        # BackgroundTasks only run with a returned response, so on the error
        # path the temp files must be removed here or they leak.
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during enhancement operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during enhancement: {str(e)}") from e
def _separation_content_disposition(filename: str) -> str:
    """Return an ``attachment`` Content-Disposition value that is safe for *filename*.

    HTTP header values must be latin-1 encodable. A name derived from an
    arbitrary upload (non-ASCII characters, quotes, spaces) therefore cannot
    be embedded verbatim. Such names are percent-encoded into the RFC 5987
    ``filename*`` form, the same approach Starlette's ``FileResponse`` uses.
    """
    from urllib.parse import quote

    quoted = quote(filename)
    if quoted != filename:
        return f"attachment; filename*=utf-8''{quoted}"
    return f'attachment; filename="{filename}"'


def _build_separated_stems_zip(
    stem_tensors: Dict[str, Any], stem_names: List[str], output_format: str
) -> tuple:
    """Save each named stem to a temp file and pack them into an in-memory ZIP.

    A stem that is missing from *stem_tensors*, or that fails to save or zip,
    is logged and skipped rather than aborting the whole request. Each temp
    file is deleted right after it has been copied into the archive, because
    the archive itself lives entirely in memory.

    Returns:
        ``(zip_bytes, files_added)``. The caller decides what to do when
        ``files_added`` is 0.
    """
    files_added = 0
    with io.BytesIO() as zip_buffer:
        # The ZipFile must be closed (central directory written) before the
        # buffer contents are read, hence the nested context managers.
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
            for stem_name in stem_names:
                if stem_name not in stem_tensors:
                    # Should be prevented by the endpoint's stem validation.
                    logger.warning(f"Requested stem '{stem_name}' not found in model output (validation error?).")
                    continue
                stem_path = None
                try:
                    # Use the model's native sampling rate for output (DEMUCS_SR)
                    stem_path = save_hf_audio(stem_tensors[stem_name], DEMUCS_SR, output_format)
                    archive_name = f"{stem_name}.{output_format}"
                    zipf.write(stem_path, arcname=archive_name)
                    files_added += 1
                    logger.info(f"Added '{archive_name}' to ZIP.")
                except Exception as save_err:
                    # Log error saving/zipping this stem but continue with others
                    logger.error(f"Failed to save or add stem '{stem_name}' to zip: {save_err}", exc_info=True)
                finally:
                    if stem_path:
                        cleanup_file(stem_path)
        return zip_buffer.getvalue(), files_added


@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: Optional[str] = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use (defaults to primary)."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended)."),
):
    """Separate music into stems using a pre-loaded Demucs model and return a ZIP archive.

    The upload is saved to a temp file and loaded at ``DEMUCS_SR``. It is then
    passed to ``_run_separation_sync`` in a worker thread. Each requested stem
    is written into the ZIP as ``<stem>.<output_format>``.

    Args:
        background_tasks: Used to delete the temp input file once a
            successful response has been sent.
        file: The uploaded music file.
        model_key: Key into ``separation_models``; a falsy value falls back to
            ``SEPARATION_MODEL_KEY``.
        stems: Stem names, case-insensitive. Supply them as repeated form
            fields or as a single comma-separated value.
        output_format: Format/extension handed to ``save_hf_audio``
            (normalised to lower case).

    Raises:
        HTTPException: 501 if the AI libraries failed to import, 503 if the
            model is not loaded, 422 for empty/invalid stems, and 500 if
            processing fails or no stem could be produced.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")

    actual_model_key = model_key or SEPARATION_MODEL_KEY
    if actual_model_key not in separation_models:
        logger.error(f"Separation model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{actual_model_key}' is not loaded or available. Check server startup logs.")
    loaded_model = separation_models[actual_model_key]
    model_sources = list(loaded_model.sources)

    # Accept repeated form fields ("stems=vocals&stems=drums") as well as a
    # single comma-separated value ("stems=vocals,drums").
    requested_stems = {
        part.strip().lower()
        for entry in stems
        for part in entry.split(",")
        if part.strip()
    }
    if not requested_stems:
        raise HTTPException(status_code=422, detail="No stems requested for separation.")
    # Every requested stem must be produced by this model
    invalid_stems = requested_stems - set(model_sources)
    if invalid_stems:
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested: {', '.join(sorted(invalid_stems))}. Model '{actual_model_key}' provides: {', '.join(model_sources)}")
    # Follow the model's own source order so the ZIP layout is deterministic.
    ordered_stems = [name for name in model_sources if name in requested_stems]

    output_format = output_format.lower()
    logger.info(f"Separate request: file='{file.filename}', model='{actual_model_key}', stems={ordered_stems}, format='{output_format}'")

    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)
    try:
        # Load audio as tensor, resampled to the Demucs rate (DEMUCS_SR)
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Separation task completed successfully.")

        logger.info("Creating ZIP archive in memory...")
        zip_bytes, files_added_to_zip = _build_separated_stems_zip(
            all_separated_stems_tensors, ordered_stems, output_format
        )
        if files_added_to_zip == 0:
            logger.error("Failed to add any requested stems to the ZIP archive.")
            raise HTTPException(status_code=500, detail="Failed to generate any of the requested stems.")

        # UploadFile.filename is Optional; fall back to a generic stem.
        base_name = os.path.splitext(file.filename or "audio")[0]
        zip_filename = f"separated_{actual_model_key}_{base_name}.zip"
        logger.info(f"Sending ZIP file: {zip_filename}")
        return StreamingResponse(
            iter([zip_bytes]),  # StreamingResponse needs an iterator
            media_type="application/zip",
            headers={"Content-Disposition": _separation_content_disposition(zip_filename)},
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during separation: {http_exc.detail}")
        # BackgroundTasks only run with a returned response, so on the error
        # path the input temp file must be removed here or it leaks.
        cleanup_file(input_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during separation operation: {e}", exc_info=True)
        cleanup_file(input_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during separation: {str(e)}") from e
# --- How to Run ---
# 1. Ensure FFmpeg is installed and accessible in your PATH.
# 2. Save this code as `app.py`.
# 3. Create `requirements.txt` (including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs, python-multipart, protobuf).
# 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!).
# 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (7860 is the Hugging Face Spaces default port; add --reload only for local development, never in production).
#
# --- WARNING ---
# - AI models require SIGNIFICANT RAM (often 8GB+) and CPU/GPU. Inference can be SLOW (minutes). Free HF Spaces might time out or lack resources.
# - First run downloads models (can take a long time/lots of disk space).
# - Ensure model names (e.g., "htdemucs") are correct.
# - MONITOR STARTUP LOGS carefully for model loading success/failure. A model that fails to load at startup makes its endpoint return 503 errors later.