# Spaces: Running  (hosting-page status banner captured with this listing; not part of app.py)
# app.py | |
import os | |
import uuid | |
import tempfile | |
import logging | |
import asyncio | |
from typing import List, Optional, Dict, Any | |
import traceback # For detailed error logging | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query | |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
import io | |
import zipfile | |
# --- Basic Editing Imports --- | |
from pydub import AudioSegment | |
from pydub.exceptions import CouldntDecodeError | |
# --- AI & Advanced Audio Imports --- | |
# Add extra logging around imports | |
# Start-up logging: a dedicated "AppInit" logger with its own console handler,
# used to report import and model-loading progress.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Console handler (INFO and above); load_hf_models reuses it for its own logger.
ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

logger_init = logging.getLogger("AppInit")
logger_init.setLevel(logging.INFO)
# getLogger() hands back the same object if the script is reloaded, so only
# attach the handler when the logger has none yet (avoids duplicated lines).
if len(logger_init.handlers) == 0:
    logger_init.addHandler(ch)
# Import the optional AI/audio stack defensively. AI_LIBS_AVAILABLE stays False
# unless every import below succeeds, and the rest of the module checks that
# flag before touching any of these libraries.
AI_LIBS_AVAILABLE = False
try:
    logger_init.info("Importing torch...")
    import torch
    logger_init.info("Importing soundfile...")
    import soundfile as sf
    logger_init.info("Importing numpy...")
    import numpy as np
    logger_init.info("Importing librosa...")
    import librosa
    logger_init.info("Importing speechbrain...")
    import speechbrain.pretrained
    logger_init.info("Importing demucs...")
    import demucs.separate
    import demucs.apply
    logger_init.info("AI and advanced audio libraries imported successfully.")
    AI_LIBS_AVAILABLE = True
except (ImportError, OSError) as e:
    # OSError is handled alongside ImportError: a package whose native shared
    # library cannot be loaded (e.g. soundfile without libsndfile) fails with
    # OSError at import time, and that should disable the AI features rather
    # than abort the whole application.
    logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
    logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
    logger_init.error("AI features will be unavailable.")
    # Define placeholders so later references to these names resolve (to None)
    # instead of raising NameError when the import failed part-way through.
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None
# --- Configuration & Setup --- | |
# Directory that receives every uploaded and generated file.
TEMP_DIR = tempfile.gettempdir()
try:
    # The system temp dir normally exists already; create it where it does not.
    os.makedirs(TEMP_DIR, exist_ok=True)
except OSError as mkdir_err:
    logger_init.error(f"Could not create temporary directory {TEMP_DIR}: {mkdir_err}")
    # Fall back to the working directory rather than failing at import time.
    TEMP_DIR = "."
    logger_init.warning(f"Using current directory '{TEMP_DIR}' for temporary files.")

# Module logger used by the helpers and endpoint handlers below.
logger = logging.getLogger(__name__)

# --- Model registry ---
# Keys under which the loaded models are stored in the dictionaries below.
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
SEPARATION_MODEL_KEY = "htdemucs"  # Or use "mdx_extra_q" for a faster quantized one
enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

# Target sampling rates (confirm from model specifics if necessary).
ENHANCEMENT_SR = 16000  # Sepformer WHAMR operates at 16kHz
DEMUCS_SR = 44100  # Demucs default is 44.1kHz

# --- Device selection ---
# CPU unless torch imported successfully and reports a CUDA device.
DEVICE = "cpu"
if AI_LIBS_AVAILABLE and torch:
    if torch.cuda.is_available():
        DEVICE = "cuda"
    logger_init.info(f"Selected device for AI models: {DEVICE}")
else:
    logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")
# --- Helper Functions --- | |
def cleanup_file(file_path: str):
    """Remove a temporary file without ever raising.

    Args:
        file_path: Path to delete. Empty values and non-string values are
            ignored.

    Returns:
        None. A path that no longer exists is treated as already cleaned up;
        any other failure is logged and swallowed so that one bad path does
        not stop the cleanup of other files.
    """
    if not file_path or not isinstance(file_path, str):
        return
    try:
        # Remove directly instead of checking existence first: the same path
        # can be cleaned up from more than one place, and a check-then-remove
        # sequence would log a spurious error when another call wins the race.
        os.remove(file_path)
    except FileNotFoundError:
        pass  # Already gone.
    except Exception as e:
        # Log error but don't crash the cleanup process for other files
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Persist an uploaded file under TEMP_DIR and return the new path.

    Args:
        upload_file: The incoming upload; it is closed before this returns.
        prefix: Text placed in front of the random part of the file name.

    Returns:
        Path of the temporary copy, named ``<prefix><uuid hex><extension>``.

    Raises:
        HTTPException: 400 when the upload or its filename is missing, 500
            when the content cannot be written (a partial file is removed).
    """
    if not (upload_file and upload_file.filename):
        raise HTTPException(status_code=400, detail="Invalid file upload object.")

    # Keep the client's extension; fall back to .wav when there is none, as
    # it's widely compatible for loading.
    extension = os.path.splitext(upload_file.filename)[1] or ".wav"
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{extension}")

    try:
        logger.debug(f"Attempting to save uploaded file to: {temp_file_path}")
        with open(temp_file_path, "wb") as buffer:
            # Copy in 1MB chunks so large uploads are never held in memory whole.
            while True:
                chunk = await upload_file.read(1024 * 1024)
                if not chunk:
                    break
                buffer.write(chunk)
        logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {temp_file_path}: {e}", exc_info=True)
        cleanup_file(temp_file_path)  # Drop whatever was written before the failure
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        # Always release the upload, whether or not saving succeeded.
        try:
            await upload_file.close()
        except Exception:
            pass  # Best effort: a failing close must not mask the real outcome
# --- Audio Loading/Saving Functions --- | |
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[torch.Tensor, int]":
    """Load an audio file as a mono float32 tensor on DEVICE, optionally resampled.

    The return annotation is a string on purpose: when the AI libraries fail to
    import, ``torch`` is the placeholder ``None`` and an eagerly evaluated
    ``torch.Tensor`` annotation would crash the module while it is being loaded.

    Args:
        file_path: Path of the audio file to read with soundfile.
        target_sr: Sampling rate to resample to. A falsy value, or a value
            equal to the file's own rate, skips resampling.

    Returns:
        ``(audio_tensor, sampling_rate)``: a 1-D float tensor moved to DEVICE
        and the rate it is sampled at.

    Raises:
        HTTPException: 501 when the AI libraries are unavailable, 500 when the
            file is missing or processing fails, 415 when soundfile cannot
            decode it. On the 415 path and on processing failures the input
            file is removed.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")
        # Ensure mono
        if audio.ndim > 1:
            # The smaller dimension is taken to be the channel axis.
            channel_dim = np.argmin(audio.shape)
            if audio.shape[channel_dim] > 1 and audio.shape[channel_dim] < 10:  # Heuristic: <10 channels
                logger.info(f"Detected {audio.shape[channel_dim]} channels. Converting to mono by averaging axis {channel_dim}.")
                audio = np.mean(audio, axis=channel_dim)
            else:  # Fallback or if shape is ambiguous (e.g., very short stereo)
                logger.warning(f"Audio has shape {audio.shape}. Taking first channel/element assuming mono or channel-first.")
                # Select first index of the likely channel dimension
                audio = audio[0] if channel_dim == 0 else audio[:, 0]
            logger.debug(f"Shape after mono conversion: {audio.shape}")
        # Ensure it's now 1D
        audio = audio.flatten()
        # Convert numpy array to torch tensor
        audio_tensor = torch.from_numpy(audio).float()
        # Resample if necessary using librosa
        current_sr = orig_sr
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("Librosa missing for resampling")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
            # Librosa works on numpy, so round-trip through an array.
            audio_np = audio_tensor.numpy()
            resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best')
            audio_tensor = torch.from_numpy(resampled_audio_np).float()
            current_sr = target_sr
            logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
        # Ensure tensor is on the correct device
        return audio_tensor.to(DEVICE), current_sr
    except sf.SoundFileError as sf_err:
        logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}")
    except Exception as e:
        logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.")
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Write AI-processed audio to a temporary file and return its path.

    Args:
        audio_data: A torch tensor or NumPy array of samples. Multi-dimensional
            data is reduced to one dimension before writing.
        sampling_rate: Sampling rate to store in the output file.
        output_format: Target format, case-insensitive. ``wav`` and ``flac``
            are written with soundfile; anything else is exported via pydub.

    Returns:
        Path of the file created under TEMP_DIR.

    Raises:
        HTTPException: 501 when the AI libraries are unavailable, 500 when the
            data cannot be converted or written (a partial output is removed).
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
    # Normalise the format once so the file extension, the soundfile writer
    # and the pydub exporter all see the same spelling (previously only the
    # extension was lower-cased and pydub received the caller's raw value).
    fmt = output_format.lower()
    output_filename = f"ai_output_{uuid.uuid4().hex}.{fmt}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        # Convert tensor to numpy array if needed
        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to NumPy array.")
            # Ensure tensor is on CPU before converting to numpy
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
        # Ensure data is float32
        if audio_np.dtype != np.float32:
            logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
            audio_np = audio_np.astype(np.float32)
        # Clip values to avoid potential issues with formats expecting [-1, 1]
        audio_np = np.clip(audio_np, -1.0, 1.0)
        # Ensure audio is 1D (mono) before saving with soundfile or pydub conversion
        if audio_np.ndim > 1:
            logger.warning(f"Output audio data has {audio_np.ndim} dimensions, attempting to flatten or take first dimension.")
            # Average across the smallest axis when it looks like a channel axis
            channel_dim = np.argmin(audio_np.shape)
            if audio_np.shape[channel_dim] > 1 and audio_np.shape[channel_dim] < 10:
                audio_np = np.mean(audio_np, axis=channel_dim)
            else:  # Otherwise just flatten
                audio_np = audio_np.flatten()
        if fmt in ('wav', 'flac'):
            # Use soundfile (preferred for wav/flac)
            sf.write(output_path, audio_np, sampling_rate, format=fmt.upper())
        else:
            # For lossy formats, use pydub
            logger.debug(f"Using pydub to export to lossy format: {output_format}")
            # Scale float32 [-1, 1] to int16 for pydub
            audio_int16 = (audio_np * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1  # Mono after the reduction above
            )
            # Pydub might need explicit ffmpeg path in some envs
            # (set AudioSegment.converter to the ffmpeg binary if needed).
            segment.export(output_path, format=fmt)
        logger.info(f"Successfully saved AI audio to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)  # Attempt cleanup on saving failure
        raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.")
# --- Pydub Loading/Exporting (for basic edits) --- | |
def load_audio_pydub(file_path: str) -> AudioSegment:
    """Load an audio file with pydub.

    The file extension is first passed to pydub as an explicit format hint. If
    decoding with that hint fails, the load is retried with format
    auto-detection, so that an extension which is not a valid decoder name, or
    a file whose extension does not match its content, is not rejected outright.

    Args:
        file_path: Path of the audio file to load.

    Returns:
        The decoded ``AudioSegment``.

    Raises:
        HTTPException: 500 when the file is missing or loading fails
            unexpectedly, 415 when pydub cannot decode it. The input file is
            removed on the 415 path and on unexpected failures.
    """
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")
    try:
        logger.debug(f"Loading audio with pydub: {file_path}")
        # Explicitly provide format if possible, helps pydub sometimes
        file_ext = os.path.splitext(file_path)[1][1:].lower()
        if file_ext:
            try:
                audio = AudioSegment.from_file(file_path, format=file_ext)
            except CouldntDecodeError:
                # The extension comes from the client-supplied filename, so it
                # may be wrong or unknown to the decoder; let pydub detect.
                logger.warning(f"Pydub could not decode {file_path} as '{file_ext}'; retrying with format auto-detection.")
                audio = AudioSegment.from_file(file_path)
        else:
            audio = AudioSegment.from_file(file_path)  # Let pydub detect
        logger.info(f"Loaded audio using pydub from: {file_path}")
        return audio
    except CouldntDecodeError as e:
        logger.warning(f"Pydub CouldntDecodeError for {file_path}: {e}")
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
    except Exception as e:
        logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")
# File extensions whose name differs from the ffmpeg muxer that writes them.
# pydub hands the format string to ffmpeg as the output format, so these
# extensions have to be translated before exporting.
_PYDUB_EXPORT_FORMAT_ALIASES = {"m4a": "ipod", "aac": "adts", "wma": "asf"}


def export_audio_pydub(audio: AudioSegment, format: str) -> str:
    """Export a pydub AudioSegment to a temporary file and return the path.

    Args:
        audio: The segment to write.
        format: Requested output format, case-insensitive. It is used verbatim
            (lower-cased) as the file extension; for the encoder it is mapped
            through ``_PYDUB_EXPORT_FORMAT_ALIASES`` when the extension is not
            itself an encoder format name.

    Returns:
        Path of the file created under TEMP_DIR, named ``edited_<uuid>.<format>``.

    Raises:
        HTTPException: 500 when the export fails (a partial output is removed).
    """
    extension = format.lower()
    export_format = _PYDUB_EXPORT_FORMAT_ALIASES.get(extension, extension)
    output_filename = f"edited_{uuid.uuid4().hex}.{extension}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
        audio.export(output_path, format=export_format)
        return output_path
    except Exception as e:
        logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
        cleanup_file(output_path)  # Cleanup if export failed
        raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
# --- Synchronous AI Inference Functions --- | |
def _run_enhancement_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "torch.Tensor":
    """Run the SpeechBrain enhancement model synchronously on one waveform.

    Tensor annotations are strings so that defining this function does not
    touch ``torch`` (which is ``None`` when the AI imports failed).

    Args:
        model: Loaded enhancement model; must expose ``parameters()`` and
            either ``enhance_batch`` or ``forward``.
        audio_tensor: Input waveform. A 1-D tensor gets a batch axis added.
        sampling_rate: Sampling rate of the input; only used for logging here.

    Returns:
        The model output with the leading batch axis squeezed, on the CPU.

    Raises:
        ValueError: When the AI libraries or the model are unavailable.
        Exception: Any inference error is logged and re-raised for the caller.
    """
    import inspect  # Local import: only needed for the signature probe below.

    if not AI_LIBS_AVAILABLE or not model:
        raise ValueError("Enhancement model/libs not available")
    try:
        logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device  # Check model's current device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)
        # Add batch dimension if model expects it (most do)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        with torch.no_grad():
            enhance_method = getattr(model, "enhance_batch", getattr(model, "forward", None))
            if enhance_method is None:
                raise AttributeError("Enhancement model exposes neither 'enhance_batch' nor 'forward'.")
            # Inspect the real parameter list; co_varnames would also match a
            # local variable that merely happens to be called "lengths".
            try:
                accepts_lengths = "lengths" in inspect.signature(enhance_method).parameters
            except (TypeError, ValueError):
                accepts_lengths = False
            if accepts_lengths:
                # SpeechBrain length arguments are relative to the longest item
                # in the batch (1.0 = full length), not absolute sample counts.
                rel_lengths = torch.ones(audio_tensor.shape[0], device=model_device)
                enhanced_tensor = enhance_method(audio_tensor, lengths=rel_lengths)
            else:
                enhanced_tensor = enhance_method(audio_tensor)
        # Remove batch dimension from output before returning, move back to CPU
        enhanced_audio = enhanced_tensor.squeeze(0).cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise  # Re-raise to be caught by the async wrapper
def _run_separation_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "Dict[str, torch.Tensor]":
    """Run Demucs source separation synchronously and return mono stems.

    Tensor annotations are strings so that defining this function does not
    touch ``torch`` (which is ``None`` when the AI imports failed).

    Args:
        model: Loaded Demucs model; must expose ``parameters()``,
            ``audio_channels`` and ``sources``.
        audio_tensor: Input audio. 1-D input becomes ``(1, 1, N)``, 2-D input
            is treated as ``(B, N)`` and becomes ``(B, 1, N)``; a mono channel
            is repeated when the model expects more channels.
        sampling_rate: Sampling rate of the input; only used for logging here.

    Returns:
        Mapping of stem name to a channel-averaged tensor on the CPU, for the
        first item of the batch.

    Raises:
        ValueError: When the model/libraries are unavailable or the channel
            count cannot be matched to the model.
        RuntimeError: When the demucs package is missing.
        Exception: Any inference error is logged and re-raised for the caller.
    """
    if not AI_LIBS_AVAILABLE or not model:
        raise ValueError("Separation model/libs not available")
    if not demucs:
        raise RuntimeError("Demucs library missing")
    try:
        logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)
        # Demucs expects audio as (batch, channels, samples)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, N)
        elif audio_tensor.ndim == 2:
            audio_tensor = audio_tensor.unsqueeze(1)  # (B, 1, N)
        # Repeat channel if model expects stereo but input is mono
        if audio_tensor.shape[1] != model.audio_channels:
            if audio_tensor.shape[1] == 1:
                logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
                audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
            else:
                raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
        with torch.no_grad():
            # apply_model handles chunking etc. It takes a *batched* mix of
            # shape (batch, channels, samples) and returns
            # (batch, stems, channels, samples), so keep the batch axis on the
            # way in (first item only) and strip it from the result.
            # Note: shifts=1, split=True are common defaults for quality
            batch_out = demucs.apply.apply_model(
                model,
                audio_tensor[:1],
                device=model_device,
                shifts=1,
                split=True,
                overlap=0.25,
                progress=False,  # Disable progress bar in logs
            )
            out = batch_out[0]  # (stems, channels, samples)
        logger.debug(f"Raw separated sources tensor shape: {out.shape}")
        # Map stems based on the model's sources list, averaging channels to
        # mono for simplicity and moving each stem to the CPU.
        output_stems = {
            name: out[i].mean(dim=0).detach().cpu()
            for i, name in enumerate(model.sources)
        }
        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
# --- Model Loading Function (Enhanced Logging) --- | |
def load_hf_models():
    """Load the enhancement and separation models into the module registries.

    Each model is loaded independently: a failure is logged with its traceback
    and leaves the corresponding registry empty, so the other model and the
    basic editing features stay usable. Does nothing (beyond logging) when the
    AI libraries failed to import.

    Side effects:
        Populates ``enhancement_models[ENHANCEMENT_MODEL_KEY]`` and
        ``separation_models[SEPARATION_MODEL_KEY]`` on success.
    """
    logger_load = logging.getLogger("ModelLoader")  # Use specific logger
    logger_load.setLevel(logging.INFO)
    # Ensure handler is attached if logger is newly created
    if not logger_load.handlers and ch:
        logger_load.addHandler(ch)
    global enhancement_models, separation_models
    if not AI_LIBS_AVAILABLE:
        logger_load.error("Core AI libraries not available. Cannot load AI models.")
        return
    load_success_flags = {"enhancement": False, "separation": False}

    # --- Load Enhancement Model ---
    enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
    logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
        # NOTE(review): SepformerEnhancement is kept as the first choice for
        # backward compatibility, but it may not exist in the installed
        # SpeechBrain; the interface normally used for SepFormer checkpoints is
        # SepformerSeparation, so fall back to it. Confirm against the
        # installed version.
        enhancer_cls = getattr(speechbrain.pretrained, "SepformerEnhancement", None)
        if enhancer_cls is None:
            enhancer_cls = speechbrain.pretrained.SepformerSeparation
        # Consider passing savedir=<dir under TEMP_DIR> if cache issues arise
        # in HF Spaces.
        enhancer = enhancer_cls.from_hparams(
            source=enhancement_model_hparams,
            run_opts={"device": DEVICE}
        )
        model_device = next(enhancer.parameters()).device
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
        logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
        load_success_flags["enhancement"] = True
    except Exception:
        logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
        logger_load.error(f"Traceback: {traceback.format_exc()}")  # Log full traceback separately
        logger_load.warning("Enhancement features will be unavailable.")

    # --- Load Separation Model ---
    separation_model_name = SEPARATION_MODEL_KEY  # e.g., "htdemucs"
    logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
        # Pretrained Demucs models are obtained through demucs.pretrained
        # (demucs.apply only provides inference helpers). get_model downloads
        # the checkpoint when needed; the model is then moved to DEVICE.
        from demucs.pretrained import get_model
        separator = get_model(separation_model_name)
        separator.to(DEVICE)
        separator.eval()
        model_device = next(separator.parameters()).device
        separation_models[SEPARATION_MODEL_KEY] = separator
        logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
        logger_load.info(f"Separation model available sources: {separator.sources}")
        load_success_flags["separation"] = True
    except Exception:
        logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
        logger_load.error(f"Traceback: {traceback.format_exc()}")
        logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
        logger_load.warning("Separation features will be unavailable.")

    logger_load.info("--- Model loading attempts finished ---")
    logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
    logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")
# --- FastAPI App --- | |
# The FastAPI application object; the metadata below is what FastAPI receives
# as the API's title, description and version.
app = FastAPI(
    version="2.1.2",  # Incremented version
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
)
async def startup_event():
    """Log the start-up sequence and load the AI models when possible.

    Model loading is blocking, so it is pushed to a worker thread with
    ``asyncio.to_thread`` and awaited; when the AI libraries failed to import
    the load is skipped and an error is logged instead.
    """
    # NOTE(review): no registration of this coroutine with `app` is visible in
    # this listing - confirm it is actually hooked into application start-up.
    # Use the init logger for startup messages
    logger_init.info("--- FastAPI Application Startup ---")
    if not AI_LIBS_AVAILABLE:
        logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
    else:
        logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
        await asyncio.to_thread(load_hf_models)
        logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
    logger_init.info("--- Startup sequence complete ---")
# --- API Endpoints --- | |
def read_root():
    """Return a welcome payload describing the API and the state of the AI models.

    Returns:
        A dict with a welcome message, an overall AI-library status string, a
        per-model status mapping (or a single "AI Status" entry when the
        libraries failed to import), the list of basic endpoints, and a note
        about requirements.
    """
    # NOTE(review): no route registration is visible on this handler in this
    # listing - confirm how it is attached to `app`.
    if not AI_LIBS_AVAILABLE:
        ai_features_status = {"AI Status": "Libraries Failed Import"}
    else:
        not_loaded = "Failed to load (check startup logs)"
        ai_features_status = {
            ENHANCEMENT_MODEL_KEY: "Loaded" if enhancement_models else not_loaded,
        }
        if not separation_models:
            ai_features_status[SEPARATION_MODEL_KEY] = not_loaded
        else:
            # Report the stems the loaded separation model can produce.
            model = separation_models.get(SEPARATION_MODEL_KEY)
            sources_str = ', '.join(model.sources) if model else 'N/A'
            ai_features_status[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {sources_str})"

    overall_status = "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed"
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "status": overall_status,
        "ai_models_status": ai_features_status,
        "basic_endpoints": ["/trim", "/concat", "/volume", "/convert"],
        "notes": "Requires FFmpeg. AI features require successful model loading at startup."
    }
# --- Basic Editing Endpoints --- | |
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
    end_ms: int = Form(..., gt=0, description="End time in milliseconds.")  # Ensure end > 0
):
    """Trim an uploaded audio file to ``[start_ms, end_ms)`` using Pydub.

    Args:
        background_tasks: Used to delete the temporary input and output files
            after the response has been sent.
        file: The audio upload.
        start_ms: Start of the kept range, in milliseconds.
        end_ms: End of the kept range, in milliseconds; must exceed start_ms.
            A value past the end of the audio keeps everything up to the end.

    Returns:
        A ``FileResponse`` with the trimmed audio in the upload's own format
        (mp3 when the upload has no usable extension).

    Raises:
        HTTPException: 422 when the range is empty or starts at/after the end
            of the audio; errors from saving, loading or exporting are
            re-raised; anything else becomes a 500.
    """
    if end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")
    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    input_path = await save_upload_file(file, prefix="trim_in_")
    # Schedule cleanup immediately after saving, even if loading fails later
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None  # Define before try block
    try:
        audio = load_audio_pydub(input_path)  # Can raise HTTPException
        # Slicing past the end silently yields an empty segment; report it
        # as a client error instead of returning an empty file.
        if start_ms >= len(audio):
            raise HTTPException(
                status_code=422,
                detail=f"Start time ({start_ms}ms) is at or beyond the end of the audio ({len(audio)}ms)."
            )
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
        # Strip any directory part the client sent so it cannot leak into the
        # download name.
        base_name, extension = os.path.splitext(os.path.basename(file.filename))
        original_format = extension[1:].lower()
        # Use mp3 as default only if no extension or if it's implausibly long.
        if not original_format or len(original_format) > 5:
            original_format = "mp3"
            logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")
        output_path = export_audio_pydub(trimmed_audio, original_format)  # Can raise HTTPException
        background_tasks.add_task(cleanup_file, output_path)  # Schedule output cleanup
        # Create a more informative filename
        output_filename = f"trimmed_{start_ms}-{end_ms}_{base_name}.{original_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",  # Best guess for media type
            filename=output_filename
        )
    except HTTPException as http_exc:
        # Raised by load/export or by the range check above: log and re-raise.
        logger.error(f"HTTP Exception during trim: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)  # Try immediate cleanup if output exists
        raise
    except Exception as e:
        # Catch other unexpected errors during trimming logic
        logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}")
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
):
    """Concatenate two or more uploaded audio files in order using Pydub.

    Uploads without a filename and files that cannot be decoded are skipped
    (and logged); the remaining files are joined in upload order.

    Args:
        background_tasks: Used to delete temporary input and output files
            after the response has been sent.
        files: The audio uploads, in the order they should be joined.
        output_format: Format passed to the exporter and used as extension.

    Returns:
        A ``FileResponse`` with the combined audio. The download name is built
        from the first file that was actually loaded and the number of other
        files that were actually appended.

    Raises:
        HTTPException: 422 for fewer than two uploads, 400 when nothing could
            be loaded; errors from saving or exporting are re-raised; anything
            else becomes a 500.
    """
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
    # Names of the uploads that made it into the result, in order. The output
    # filename is derived from these rather than from `files`, whose first
    # entry may have been skipped or may lack a filename.
    loaded_names: List[str] = []
    output_path = None  # Define before try block
    try:
        combined_audio: Optional[AudioSegment] = None
        for i, file in enumerate(files):
            if not file or not file.filename:
                logger.warning(f"Skipping invalid file upload at index {i}.")
                continue  # Skip potentially empty file entries
            input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
            # Schedule cleanup for this specific input file immediately
            background_tasks.add_task(cleanup_file, input_path)
            try:
                audio = load_audio_pydub(input_path)
                if combined_audio is None:
                    combined_audio = audio
                    logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
                else:
                    logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
                    combined_audio += audio
                loaded_names.append(file.filename)
            except HTTPException as load_exc:
                # Log error but continue trying to load other files if possible
                logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
            except Exception as load_exc:
                logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)
        if combined_audio is None:
            raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")
        logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")
        output_path = export_audio_pydub(combined_audio, output_format)  # Can raise HTTPException
        background_tasks.add_task(cleanup_file, output_path)  # Schedule output cleanup
        # Name the download after the first file that was really used;
        # basename() drops any directory part the client sent.
        first_filename_base = os.path.splitext(os.path.basename(loaded_names[0]))[0]
        output_filename = f"concat_{first_filename_base}_and_{len(loaded_names)-1}_others.{output_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename
        )
    except HTTPException as http_exc:
        # If save/export raised HTTPException, re-raise it
        logger.error(f"HTTP Exception during concat: {http_exc.detail}")
        # Cleanup for output path, inputs are handled by background tasks
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        # Catch other unexpected errors during combining logic
        logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}")
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
):
    """Adjust the volume of an uploaded audio file by ``change_db`` decibels using Pydub.

    Args:
        background_tasks: Used to delete the temporary input and output files
            after the response has been sent.
        file: The audio upload.
        change_db: Gain to apply in dB; must be a finite number.

    Returns:
        A ``FileResponse`` with the adjusted audio in the upload's own format
        (mp3 when the upload has no usable extension).

    Raises:
        HTTPException: 422 when change_db is not finite; errors from saving,
            loading or exporting are re-raised; anything else becomes a 500.
    """
    import math  # Local import: only needed for the finiteness check.

    # Float parsing can accept "nan"/"inf"; such a gain cannot be applied
    # meaningfully, so reject it as a client error before doing any work.
    if not math.isfinite(change_db):
        raise HTTPException(status_code=422, detail="Volume change (change_db) must be a finite number.")
    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    input_path = await save_upload_file(file, prefix="volume_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio = load_audio_pydub(input_path)
        # Check for potential silence before applying gain
        if audio.dBFS == -float('inf'):
            logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
        adjusted_audio = audio + change_db  # AudioSegment + number applies gain in dB
        logger.info(f"Volume adjusted by {change_db}dB.")
        original_format = os.path.splitext(file.filename)[1][1:].lower()
        if not original_format or len(original_format) > 5:
            original_format = "mp3"
        output_path = export_audio_pydub(adjusted_audio, original_format)
        background_tasks.add_task(cleanup_file, output_path)
        # Create filename; positive gains get an explicit "+" sign.
        sign = "+" if change_db >= 0 else ""
        output_filename = f"volume_{sign}{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",
            filename=output_filename
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}")
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac')."),
):
    """Convert an uploaded audio file to another format using Pydub.

    Args:
        background_tasks: Collects the temp-file cleanups that run after the
            response has been sent.
        file: The uploaded audio file.
        output_format: Target format; matched case-insensitively against the
            allowed set after surrounding whitespace is removed.

    Returns:
        A ``FileResponse`` serving the converted audio as
        ``<original name>_converted.<format>``.

    Raises:
        HTTPException: 422 for an unsupported ``output_format``; HTTP errors
            raised by helpers are re-raised unchanged; any other failure is
            reported as a 500.
    """
    # Define allowed formats explicitly
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'}
    output_format_lower = output_format.strip().lower()
    if output_format_lower not in allowed_formats:
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(allowed_formats))}")

    logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
    # UploadFile.filename may be None, and a client can send a path; keep only
    # a usable base name for building the download name.
    source_name = os.path.basename(file.filename or "") or "audio"
    input_path = await save_upload_file(file, prefix="convert_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio = load_audio_pydub(input_path)
        logger.info(f"Successfully loaded '{file.filename}' for conversion.")

        output_path = export_audio_pydub(audio, output_format_lower)
        background_tasks.add_task(cleanup_file, output_path)
        logger.info(f"Successfully exported to {output_format_lower}")

        filename_base = os.path.splitext(source_name)[0]
        output_filename = f"{filename_base}_converted.{output_format_lower}"
        # Determine media type (MIME type) - might need refinement for less common types
        media_type_map = {
            'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
            'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4',  # m4a often uses mp4 container
            'opus': 'audio/opus', 'wma': 'audio/x-ms-wma', 'aiff': 'audio/aiff',
        }
        media_type = media_type_map.get(output_format_lower, 'application/octet-stream')  # Default binary if unknown
        return FileResponse(
            path=output_path,
            media_type=media_type,
            filename=output_filename,
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
        # The scheduled background tasks are attached to the returned response,
        # so they do not run when the handler raises: clean up explicitly.
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}") from e
# --- AI Endpoints --- | |
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    # Keep model_key optional for now, assumes default if only one loaded
    model_key: Optional[str] = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use (defaults to primary)."),
    output_format: str = Form("wav", description="Output format (wav, flac recommended)."),
):
    """Enhance speech audio using a pre-loaded SpeechBrain model.

    The upload is loaded with ``load_audio_for_hf`` at ``ENHANCEMENT_SR``, the
    model runs in a worker thread through ``asyncio.to_thread`` so the event
    loop stays responsive, and the result is written with ``save_hf_audio``.

    Args:
        background_tasks: Collects the temp-file cleanups that run after the
            response has been sent.
        file: The uploaded speech recording.
        model_key: Key into ``enhancement_models``; an empty value falls back
            to ``ENHANCEMENT_MODEL_KEY``.
        output_format: Output audio format; lower-cased and stripped before use.

    Returns:
        A ``FileResponse`` serving the enhanced audio as ``enhanced_<name>.<format>``.

    Raises:
        HTTPException: 501 when the AI libraries are unavailable, 503 when the
            requested model is not loaded; HTTP errors raised by helpers are
            re-raised unchanged; any other failure is reported as a 500.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")

    # Use the provided key or the default
    actual_model_key = model_key or ENHANCEMENT_MODEL_KEY
    if actual_model_key not in enhancement_models:
        logger.error(f"Enhancement model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{actual_model_key}' is not loaded or available. Check server startup logs.")
    loaded_model = enhancement_models[actual_model_key]

    logger.info(f"Enhance request: file='{file.filename}', model='{actual_model_key}', format='{output_format}'")
    # Normalise once so the saved file, the download name and the media type agree.
    target_format = output_format.strip().lower()
    # UploadFile.filename may be None, and a client can send a path; keep only
    # a usable base name for building the download name.
    source_name = os.path.basename(file.filename or "") or "audio"
    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)

        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Enhancement task completed.")

        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, target_format)
        background_tasks.add_task(cleanup_file, output_path)

        output_filename = f"enhanced_{os.path.splitext(source_name)[0]}.{target_format}"
        # Formats whose registered MIME type is not simply "audio/<extension>".
        mime_overrides = {"mp3": "audio/mpeg", "m4a": "audio/mp4", "wma": "audio/x-ms-wma"}
        media_type = mime_overrides.get(target_format, f"audio/{target_format}")
        return FileResponse(path=output_path, media_type=media_type, filename=output_filename)
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during enhancement: {http_exc.detail}")
        # The scheduled background tasks are attached to the returned response,
        # so they do not run when the handler raises: clean up explicitly.
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during enhancement operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during enhancement: {str(e)}") from e
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: Optional[str] = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use (defaults to primary)."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended)."),
):
    """Separate music into stems using a pre-loaded Demucs model and return a ZIP archive.

    The upload is loaded with ``load_audio_for_hf`` at ``DEMUCS_SR``, the model
    runs in a worker thread through ``asyncio.to_thread``, and every requested
    stem is written with ``save_hf_audio`` and packed into an in-memory ZIP.
    A stem that fails to save or zip is logged and skipped; the request only
    fails when no stem at all could be produced.

    Args:
        background_tasks: Collects the input temp-file cleanup that runs after
            the response has been sent.
        file: The uploaded music file.
        model_key: Key into ``separation_models``; an empty value falls back
            to ``SEPARATION_MODEL_KEY``.
        stems: Stem names to extract. Matched case-insensitively; a single
            comma-separated value is accepted as well as repeated fields.
        output_format: Audio format of each stem; lower-cased and stripped.

    Returns:
        A ``StreamingResponse`` carrying the ZIP archive as an attachment.

    Raises:
        HTTPException: 501 when the AI libraries are unavailable, 503 when the
            model is not loaded, 422 for missing or invalid stems, 500 when no
            stem could be generated or on any unexpected failure.
    """
    # Local import: only this endpoint builds a Content-Disposition header by hand.
    from urllib.parse import quote

    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")

    actual_model_key = model_key or SEPARATION_MODEL_KEY
    if actual_model_key not in separation_models:
        logger.error(f"Separation model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{actual_model_key}' is not loaded or available. Check server startup logs.")
    loaded_model = separation_models[actual_model_key]

    valid_stems = set(loaded_model.sources)
    # Some clients submit a list form field as ONE comma-separated value, so
    # split on commas as well as accepting repeated fields; drop blank entries.
    requested_stems = {
        part.strip().lower()
        for raw in stems
        for part in raw.split(",")
        if part.strip()
    }
    if not requested_stems:
        raise HTTPException(status_code=422, detail="No stems requested for separation.")
    # Check if *all* requested stems are valid for this model
    invalid_stems = requested_stems - valid_stems
    if invalid_stems:
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested: {', '.join(sorted(invalid_stems))}. Model '{actual_model_key}' provides: {', '.join(sorted(valid_stems))}")

    logger.info(f"Separate request: file='{file.filename}', model='{actual_model_key}', stems={requested_stems}, format='{output_format}'")
    # Normalise once so the saved files and the archive member names agree.
    stem_format = output_format.strip().lower()
    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)
    try:
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)

        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Separation task completed successfully.")

        # --- Create ZIP file in memory ---
        logger.info("Creating ZIP archive in memory...")
        files_added_to_zip = 0
        with io.BytesIO() as zip_buffer:
            # Leaving the inner block closes the archive, which must happen
            # before the buffer contents are read.
            with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
                # Sorted so the archive member order is deterministic.
                for stem_name in sorted(requested_stems):
                    if stem_name not in all_separated_stems_tensors:
                        # This case should be prevented by the earlier validation
                        logger.warning(f"Requested stem '{stem_name}' not found in model output (validation error?).")
                        continue
                    stem_path = None
                    try:
                        stem_path = save_hf_audio(all_separated_stems_tensors[stem_name], DEMUCS_SR, stem_format)
                        archive_name = f"{stem_name}.{stem_format}"
                        zipf.write(stem_path, arcname=archive_name)
                        files_added_to_zip += 1
                        logger.info(f"Added '{archive_name}' to ZIP.")
                    except Exception as save_err:
                        # Log error saving/zipping this stem but continue with others
                        logger.error(f"Failed to save or add stem '{stem_name}' to zip: {save_err}", exc_info=True)
                    finally:
                        # The archive holds its own copy of the data, so the
                        # temp file can go immediately, on success or failure.
                        if stem_path:
                            cleanup_file(stem_path)
            zip_bytes = zip_buffer.getvalue()

        if files_added_to_zip == 0:
            logger.error("Failed to add any requested stems to the ZIP archive.")
            raise HTTPException(status_code=500, detail="Failed to generate any of the requested stems.")

        # UploadFile.filename may be None, and a client can send a path.
        source_name = os.path.basename(file.filename or "") or "audio"
        zip_filename = f"separated_{actual_model_key}_{os.path.splitext(source_name)[0]}.zip"
        logger.info(f"Sending ZIP file: {zip_filename}")
        # Client-supplied names may contain quotes or non-latin-1 characters
        # that are not valid inside a quoted header value; percent-encode those
        # and send them with the RFC 5987 "filename*" form instead.
        quoted_filename = quote(zip_filename)
        if quoted_filename != zip_filename:
            content_disposition = f"attachment; filename*=utf-8''{quoted_filename}"
        else:
            content_disposition = f'attachment; filename="{zip_filename}"'
        return StreamingResponse(
            iter([zip_bytes]),  # StreamingResponse needs an iterator
            media_type="application/zip",
            headers={'Content-Disposition': content_disposition},
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during separation: {http_exc.detail}")
        # The scheduled background task is attached to the returned response,
        # so it does not run when the handler raises: clean up explicitly.
        cleanup_file(input_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during separation operation: {e}", exc_info=True)
        cleanup_file(input_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during separation: {str(e)}") from e
# --- How to Run --- | |
# 1. Ensure FFmpeg is installed and accessible in your PATH. | |
# 2. Save this code as `app.py`. | |
# 3. Create `requirements.txt` (including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs, python-multipart, protobuf). | |
# 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!). | |
# 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (7860 is the default port on HF Spaces; add `--reload` only for local development, never in production).
# | |
# --- WARNING --- | |
# - AI models require SIGNIFICANT RAM (often 8GB+) and CPU/GPU. Inference can be SLOW (minutes). Free HF Spaces might time out or lack resources. | |
# - First run downloads models (can take a long time/lots of disk space). | |
# - Ensure model names (e.g., "htdemucs") are correct. | |
# - MONITOR STARTUP LOGS carefully for model loading success/failure. Errors here will cause 503 errors later. |