from __future__ import annotations  # Defer annotation evaluation so np.ndarray hints are safe even if numpy fails to import below

import os
import uuid
import tempfile
import logging
import asyncio
import io  # For building the ZIP archive in memory
import zipfile
from typing import List, Optional, Dict, Any

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

# --- AI & Advanced Audio Imports ---
try:
    import torch
    from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor  # Using pipeline for simplicity where possible
    # Specific model imports might be needed depending on the chosen approach
    import soundfile as sf
    import numpy as np
    import librosa  # For resampling if needed
    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Error importing AI/Audio libraries: {e}")
    print("Ensure torch, transformers, soundfile, and librosa are installed.")
    print("AI features will be unavailable.")
    torch = None
    pipeline = None
    sf = None
    np = None
    librosa = None
# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Use dictionaries to potentially hold multiple models of each type later
enhancement_pipelines: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}  # Might store a pipeline or a model/processor pair

# Target sampling rates for the models (check the model cards on Hugging Face!)
ENHANCEMENT_SR = 16000  # Example for speechbrain/sepformer
DEMUCS_SR = 44100       # Demucs default
# --- Helper Functions ---
def cleanup_file(file_path: str):
    """Safely remove a file."""
    try:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Saves an uploaded file to a temporary location and returns the path."""
    # Generate a unique temporary file path (filename may be None, so guard it)
    _, file_extension = os.path.splitext(upload_file.filename or "")
    if not file_extension:
        file_extension = ".wav"  # Default if no extension
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
            # Stream the upload to disk in 1 MiB chunks to avoid holding it all in memory
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        await upload_file.close()
# --- Audio Loading/Saving for AI Models ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[np.ndarray, int]:
    """Loads audio using soundfile, converts to mono float32, optionally resamples."""
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
        # Convert to mono if stereo (simple channel averaging)
        if audio.ndim > 1 and audio.shape[1] > 1:
            audio = np.mean(audio, axis=1)
            logger.info("Converted audio to mono")
        # Resample if necessary
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("librosa is required for resampling but is not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
            logger.info(f"Resampled audio shape: {audio.shape}")
            current_sr = target_sr
        else:
            current_sr = orig_sr
        return audio, current_sr
    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves a NumPy audio array to a temporary file."""
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        # Ensure data is float32; soundfile handles wav/flac, pydub handles mp3 etc.
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)
        # Use soundfile for lossless formats
        if output_format.lower() in ['wav', 'flac']:
            sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
        else:
            # For lossy formats like mp3, scale the float32 array in [-1.0, 1.0]
            # to the 16-bit integer range and export via pydub (requires FFmpeg)
            audio_int16 = (audio_data * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1  # Assuming mono output from AI models for now
            )
            segment.export(output_path, format=output_format)
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")
# --- Synchronous AI Inference Functions (to be run in threads) ---
def _run_enhancement_sync(model_pipeline: Any, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Synchronous wrapper for enhancement model inference."""
    if not model_pipeline:
        raise ValueError("Enhancement model not loaded")
    try:
        logger.info(f"Running speech enhancement (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Pipeline usage depends heavily on the specific pipeline.
        # Example for a hypothetical 'audio-enhancement' pipeline:
        result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
        enhanced_audio = result["audio"]["array"]  # Adjust based on the actual pipeline output
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise  # Re-raise to be caught by the async wrapper
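
# As an illustration of what a concrete enhancement backend could look like, here is a
# minimal sketch using SpeechBrain's pretrained Sepformer interface. This is an
# assumption, not part of the pipeline-based flow above: it bypasses the transformers
# pipeline entirely, and the model name, tensor shapes, and output indexing should be
# verified against the SpeechBrain documentation before use.
#
# from speechbrain.pretrained import SepformerSeparation
#
# def _run_speechbrain_enhancement_sketch(audio_data: np.ndarray) -> np.ndarray:
#     """Hypothetical enhancement via speechbrain/sepformer-whamr-enhancement (16 kHz mono)."""
#     enhancer = SepformerSeparation.from_hparams(
#         source="speechbrain/sepformer-whamr-enhancement",
#         savedir="pretrained_models/sepformer-whamr-enhancement",
#     )
#     mix = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0)  # (batch, time)
#     with torch.no_grad():
#         est_sources = enhancer.separate_batch(mix)  # (batch, time, n_sources)
#     return est_sources[0, :, 0].cpu().numpy()  # first (only) estimated source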
def _run_separation_sync(model: Any, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
    """Synchronous wrapper for source separation model inference."""
    if not model:
        raise ValueError("Separation model not loaded")
    try:
        logger.info(f"Running source separation (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Usage depends on the separation model. A standard HF pipeline task may not
        # exist for Demucs, so the loaded model is called directly here.
        # Demucs expects stereo input in shape (batch, channels, samples):
        # duplicate the mono channel, then add a batch dimension.
        if audio_data.ndim == 1:
            audio_data = np.stack([audio_data, audio_data], axis=0)
        audio_tensor = torch.tensor(audio_data).unsqueeze(0)
        # Run on whichever device the model lives on (CPU or GPU)
        device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(device)
        with torch.no_grad():
            sources = model(audio_tensor)[0]  # Output shape (stems, channels, samples)
        # Detach, move to CPU, convert to numpy
        sources_np = sources.detach().cpu().numpy()
        # Convert back to mono for simplicity (average channels).
        # Important: the stem order (drums, bass, other, vocals) is specific to the
        # default Demucs v3/v4 models.
        stems = {
            'drums': np.mean(sources_np[0], axis=0),
            'bass': np.mean(sources_np[1], axis=0),
            'other': np.mean(sources_np[2], axis=0),
            'vocals': np.mean(sources_np[3], axis=0),
        }
        logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
        return stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
# --- Model Loading Function ---
def load_hf_models():
    """Loads Hugging Face models at startup."""
    global enhancement_pipelines, separation_models
    if torch is None or pipeline is None:
        logger.warning("Torch or Transformers not available. Skipping Hugging Face model loading.")
        return

    # --- Load Enhancement Model ---
    # Using speechbrain/sepformer-whamr-enhancement (check HF for the exact pipeline task),
    # or load the model/processor manually if no direct pipeline exists.
    enhancement_model_id = "speechbrain/sepformer-whamr-enhancement"  # Example ID
    try:
        logger.info(f"Loading enhancement model: {enhancement_model_id}...")
        # Use the appropriate task ('audio-enhancement', custom logic, or manual loading).
        # If no pipeline exists, load manually:
        #   enhancement_processor = AutoProcessor.from_pretrained(...)
        #   enhancement_model = AutoModelForSpeechSeq2Seq.from_pretrained(...)
        #   enhancement_pipelines['speechbrain_sepformer'] = {"processor": enhancement_processor, "model": enhancement_model}
        # Or, if a suitable pipeline task exists:
        #   enhancement_pipelines['speechbrain_sepformer'] = pipeline("audio-enhancement", model=enhancement_model_id)
        # For now, skip loading since it requires a specific pipeline or manual setup:
        logger.warning(f"Skipping load for {enhancement_model_id} - requires specific pipeline or manual setup.")
    except Exception as e:
        logger.error(f"Failed to load enhancement model '{enhancement_model_id}': {e}", exc_info=False)

    # --- Load Separation Model (Demucs) ---
    # Demucs is often used directly, not via a standard HF pipeline task.
    separation_model_id = "facebook/demucs"  # Or a specific variant like facebook/hybrid_demucs
    try:
        logger.info(f"Loading separation model: {separation_model_id}...")
        # Demucs usually requires loading the model directly. AutoModel might work for some
        # variants if they are configured for it on the HF Hub, but more typically you need
        # the 'demucs' package itself (pip install -U demucs) and its own loading API --
        # see the sketch after this function.
        # For now, simulate a loading failure since direct AutoModel loading may not work:
        raise NotImplementedError("Demucs loading typically requires the 'demucs' package or specific manual loading.")
    except Exception as e:
        logger.error(f"Failed to load separation model '{separation_model_id}': {e}", exc_info=False)
        logger.warning("Note: Demucs loading often requires 'pip install demucs' and specific loading code, not just AutoModel.")
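
# A minimal sketch of how Demucs could actually be loaded with the 'demucs' package
# (pip install -U demucs). This is an assumption, not a tested path: the 'htdemucs'
# variant name and the demucs.pretrained API should be verified against the installed
# demucs version, and demucs.apply.apply_model() is generally preferred over calling
# the model directly, since it handles chunking long tracks.
#
# def load_demucs_sketch():
#     from demucs.pretrained import get_model  # assuming the demucs v4 loading API
#     model = get_model("htdemucs")            # hypothetical choice of variant
#     model.eval()
#     separation_models['demucs'] = model
#     logger.info(f"Demucs loaded with stems: {model.sources}")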
# --- FastAPI App and Endpoints ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and HF model dependencies.",
    version="2.0.0",
)
@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    logger.info("Application startup: Loading AI models...")
    # Run model loading in a separate thread so the event loop is not blocked,
    # although startup still waits for the thread to finish. Consider truly
    # background loading if startup time is critical (see the sketch below).
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished (check logs for success/failure).")
# --- Basic Editing Endpoints (Mostly Unchanged) ---
@app.get("/")
def read_root():
    """Root endpoint providing a welcome message and available features."""
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    if enhancement_pipelines:
        ai_features.append("/enhance")
    if separation_models:
        ai_features.append("/separate")
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None available (models might have failed to load)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs).",
    }
# The /trim, /concat, /volume, and /convert endpoints remain largely the same as before.
# Ensure they use the updated save_upload_file and cleanup logic.
# (Code for these endpoints is omitted for brevity -- refer to the previous example.)
# ... Add /trim, /concat, /volume, /convert endpoints here; a minimal /trim sketch follows below ...
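
# A minimal sketch of what the omitted /trim endpoint could look like, assuming
# millisecond start/end form fields and pydub for the actual cut. The parameter
# names and defaults here are illustrative, not the original endpoint's contract.
@app.post("/trim")
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(0, description="Trim start position in milliseconds."),
    end_ms: int = Form(..., description="Trim end position in milliseconds."),
    output_format: str = Form("wav", description="Output format for the trimmed audio."),
):
    """Trims an audio file to [start_ms, end_ms] (illustrative sketch)."""
    if not 0 <= start_ms < end_ms:
        raise HTTPException(status_code=422, detail="Require 0 <= start_ms < end_ms.")
    input_path = await save_upload_file(file, prefix="trim_in_")
    background_tasks.add_task(cleanup_file, input_path)
    try:
        segment = AudioSegment.from_file(input_path)
    except CouldntDecodeError:
        raise HTTPException(status_code=415, detail="Could not decode the uploaded audio file.")
    trimmed = segment[start_ms:end_ms]  # pydub slices in milliseconds
    output_path = os.path.join(TEMP_DIR, f"trim_out_{uuid.uuid4().hex}.{output_format}")
    trimmed.export(output_path, format=output_format)
    background_tasks.add_task(cleanup_file, output_path)
    return FileResponse(path=output_path, media_type=f"audio/{output_format}",
                        filename=f"trimmed_{file.filename}")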
# --- AI Endpoints ---
@app.post("/enhance")
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    model_id: str = Query("speechbrain_sepformer", description="ID of the enhancement model to use (if multiple loaded)."),
    output_format: str = Form("wav", description="Output format for the enhanced audio (wav, flac recommended).")
):
    """Enhances speech audio using a pre-loaded AI model (experimental)."""
    if torch is None or sf is None or np is None:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if model_id not in enhancement_pipelines:
        raise HTTPException(status_code=503, detail=f"Enhancement model '{model_id}' is not loaded or available.")
    logger.info(f"Enhance request: file='{file.filename}', model='{model_id}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None  # Define before the try block so the except clause can reference it
    try:
        # Load audio, resampling to the model's expected rate
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
        # Run inference in a separate thread so the event loop stays responsive
        logger.info("Submitting enhancement task to background thread...")
        model_pipeline = enhancement_pipelines[model_id]  # Get the specific loaded pipeline/model
        enhanced_audio = await asyncio.to_thread(
            _run_enhancement_sync, model_pipeline, audio_data, current_sr
        )
        logger.info("Enhancement task completed.")
        # Save the result at the (resampled) model rate
        output_path = save_hf_audio(enhanced_audio, current_sr, output_format)
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"enhanced_{file.filename}"
        )
    except Exception as e:
        logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)  # Clean up if the error occurred after saving started
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
@app.post("/separate")
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_id: str = Query("demucs", description="ID of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separates music into stems (vocals, drums, bass, other) using Demucs (experimental). Returns a ZIP archive."""
    if torch is None or sf is None or np is None:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if model_id not in separation_models:
        raise HTTPException(status_code=503, detail=f"Separation model '{model_id}' is not loaded or available.")
    valid_stems = {'vocals', 'drums', 'bass', 'other'}
    requested_stems = set(s.lower() for s in stems)
    if not requested_stems.issubset(valid_stems):
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are: {', '.join(valid_stems)}")
    logger.info(f"Separate request: file='{file.filename}', model='{model_id}', stems={requested_stems}, format='{output_format}'")
    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)
    stem_output_paths: Dict[str, str] = {}
    zip_buffer = None
    try:
        # Load audio at the model's expected rate (Demucs uses 44.1 kHz)
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
        # Run inference in a separate thread
        logger.info("Submitting separation task to background thread...")
        model = separation_models[model_id]  # Get the specific loaded model
        all_separated_stems = await asyncio.to_thread(
            _run_separation_sync, model, audio_data, current_sr
        )
        logger.info("Separation task completed.")
        # --- Create the ZIP file in memory ---
        base_name = os.path.splitext(file.filename or "input")[0]
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Save only the requested stems
            for stem_name in requested_stems:
                if stem_name in all_separated_stems:
                    stem_data = all_separated_stems[stem_name]
                    # Write the stem to disk first (needed for pydub/sf.write)
                    stem_path = save_hf_audio(stem_data, current_sr, output_format)
                    stem_output_paths[stem_name] = stem_path  # Keep track for cleanup
                    background_tasks.add_task(cleanup_file, stem_path)  # Schedule cleanup
                    # Add the saved stem file to the ZIP archive, named after the upload
                    archive_name = f"{stem_name}_{base_name}.{output_format}"
                    zipf.write(stem_path, arcname=archive_name)
                    logger.info(f"Added '{archive_name}' to ZIP.")
                else:
                    logger.warning(f"Requested stem '{stem_name}' not found in model output.")
        zip_buffer.seek(0)  # Rewind the buffer before streaming
        # Return the ZIP file
        zip_filename = f"separated_stems_{base_name}.zip"
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
        )
    except Exception as e:
        logger.error(f"Error during separation operation: {e}", exc_info=True)
        # Clean up any stems that were saved before zipping failed
        for path in stem_output_paths.values():
            cleanup_file(path)
        if zip_buffer:
            zip_buffer.close()  # Close the in-memory buffer
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
# --- How to Run ---
# 1. Ensure FFmpeg is installed and accessible on PATH.
# 2. Save this code as `app.py`.
# 3. Create `requirements.txt` (a plausible reconstruction is sketched below).
# 4. Install dependencies: `pip install -r requirements.txt` (this can take time!)
# 5. Run the FastAPI server: `uvicorn app:app --reload --host 0.0.0.0`
#    (Use --host 0.0.0.0 for external/Docker access; --reload is optional.)
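#
# A requirements.txt consistent with the imports above might look like the following.
# This is a reconstruction (the original file is not shown here); pin versions as needed.
# Note that python-multipart is required by FastAPI for Form/File parameters:
#
#   fastapi
#   uvicorn[standard]
#   python-multipart
#   pydub
#   torch
#   transformers
#   soundfile
#   librosa
#   numpy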
#
# --- WARNING ---
# - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
# - The first run will download models, which can take a long time and lots of disk space.
# - Ensure the specific model IDs used are correct and compatible with HF libraries.
# - Model loading at startup might fail if dependencies are missing or resources are insufficient. Check logs!
#
# --- Example Usage (using curl) ---
#
# **Enhance:** (enhance noisy_speech.wav)
# curl -X POST "http://127.0.0.1:8000/enhance?model_id=speechbrain_sepformer" \
#      -F "file=@noisy_speech.wav" \
#      -F "output_format=wav" \
#      --output enhanced_speech.wav
#
# **Separate:** (separate vocals and drums from music.mp3)
# curl -X POST "http://127.0.0.1:8000/separate?model_id=demucs" \
#      -F "[email protected]" \
#      -F "stems=vocals" \
#      -F "stems=drums" \
#      -F "output_format=mp3" \
#      --output separated_stems.zip
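#
# The same /separate call from Python, as a small sketch (assumes the 'requests'
# package; repeated 'stems' form fields are sent as a list of tuples):
#
# import requests
#
# with open("music.mp3", "rb") as f:
#     resp = requests.post(
#         "http://127.0.0.1:8000/separate",
#         params={"model_id": "demucs"},
#         files={"file": f},
#         data=[("stems", "vocals"), ("stems", "drums"), ("output_format", "wav")],
#     )
# resp.raise_for_status()
# with open("separated_stems.zip", "wb") as out:
#     out.write(resp.content)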