from __future__ import annotations  # keep torch-based type hints valid even if torch fails to import

import os
import uuid
import tempfile
import logging
import asyncio
import io
import zipfile
from typing import List, Optional, Dict, Any

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

# --- AI & Advanced Audio Imports ---
try:
    import torch
    # Transformers is only needed for HF pipelines, not for manual speechbrain/demucs loading
    # from transformers import pipeline
    import soundfile as sf
    import numpy as np
    import librosa
    # Specific model libraries
    import speechbrain.pretrained  # newer SpeechBrain versions expose this as speechbrain.inference
    import demucs.pretrained
    import demucs.apply
    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Error importing AI/audio libraries: {e}")
    print("Ensure torch, soundfile, librosa, speechbrain, and demucs are installed.")
    print("AI features will be unavailable.")
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None
# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Use consistent keys for storing/retrieving models
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
# Choose a default Demucs model (htdemucs is good quality)
SEPARATION_MODEL_KEY = "htdemucs"  # or "mdx_extra_q" for a faster quantized model

enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

# Target sampling rates (confirm against the model cards if in doubt)
ENHANCEMENT_SR = 8000  # sepformer-whamr-enhancement is trained on 8 kHz WHAMR! data
                       # (use speechbrain/sepformer-wham16k-enhancement for 16 kHz)
DEMUCS_SR = 44100      # Demucs default is 44.1 kHz

# --- Device Selection ---
if torch:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {DEVICE}")
else:
    DEVICE = "cpu"  # fallback if torch failed to import
# --- Helper Functions ---
def cleanup_file(file_path: str):
    try:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)

async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    _, file_extension = os.path.splitext(upload_file.filename)
    if not file_extension:
        file_extension = ".wav"
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
            # Stream the upload to disk in 1 MiB chunks
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        await upload_file.close()
# --- Audio Loading/Saving for AI Models ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
    """Loads audio, converts it to a mono float32 torch tensor, and optionally resamples."""
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
        if audio.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel audio;
            # average across the channel axis to get mono
            logger.warning(f"Detected {audio.shape[1]} channels, converting to mono by averaging.")
            audio = np.mean(audio, axis=1)
        # Convert the NumPy array to a torch tensor
        audio_tensor = torch.from_numpy(audio).float()
        # Resample if necessary using librosa, then convert back to a tensor
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("librosa is required for resampling but is not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            # librosa works on NumPy arrays, so convert back temporarily
            audio_np = audio_tensor.numpy()
            resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
            audio_tensor = torch.from_numpy(resampled_audio_np).float()
            current_sr = target_sr
            logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
        else:
            current_sr = orig_sr
        # Ensure the tensor is on the correct device
        return audio_tensor.to(DEVICE), current_sr
    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path}: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves audio data (torch tensor or NumPy array) to a temporary file."""
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI-processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        # Convert tensor to NumPy array if needed
        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to NumPy array.")
            # Ensure the tensor is on the CPU before converting to NumPy
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
        # Ensure the data is float32
        if audio_np.dtype != np.float32:
            logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
            audio_np = audio_np.astype(np.float32)
        # Clip values to [-1, 1] to avoid issues with formats that expect that range
        audio_np = np.clip(audio_np, -1.0, 1.0)
        if output_format.lower() in ['wav', 'flac']:
            # soundfile is preferred for lossless formats
            sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
        else:
            # For lossy formats, use pydub (requires FFmpeg)
            logger.debug(f"Using pydub to export to lossy format: {output_format}")
            # Scale float32 [-1, 1] to int16 for pydub
            audio_int16 = (audio_np * 32767).astype(np.int16)
            # The pydub segment is built as mono; multi-channel output would need interleaving
            if audio_int16.ndim > 1:
                audio_int16 = audio_int16[0]  # take the first channel; needs better handling
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1,  # forcing mono for now
            )
            segment.export(output_path, format=output_format)
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI-processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")
# --- Synchronous AI Inference Functions ---
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
    """Synchronous wrapper for SpeechBrain enhancement model inference."""
    if not model:
        raise ValueError("Enhancement model not loaded")
    try:
        logger.info(f"Running speech enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, device: {audio_tensor.device})...")
        # SpeechBrain models take tensors directly but expect a batch dimension
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        # Move the tensor to the same device as the model
        model_device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(model_device)
        with torch.no_grad():
            # The Sepformer interface exposes separate_batch(mix); for enhancement
            # checkpoints it returns a single "clean" source: (batch, time, n_src=1)
            est_sources = model.separate_batch(audio_tensor)
        # Drop the source and batch dimensions and move back to the CPU
        enhanced_audio = est_sources[:, :, 0].squeeze(0).detach().cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
    """Synchronous wrapper for Demucs source separation model inference."""
    if not model:
        raise ValueError("Separation model not loaded")
    if not demucs:
        raise RuntimeError("Demucs library not available")
    try:
        logger.info(f"Running source separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, device: {audio_tensor.device})...")
        # Demucs expects audio as (batch, channels, samples)
        model_device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(model_device)
        # Add batch and channel dimensions if mono
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, N)
        elif audio_tensor.ndim == 2:  # should be rare since load_audio_for_hf returns mono
            logger.warning("Input tensor has 2 dims, assuming (batch, samples); adding channel dim.")
            audio_tensor = audio_tensor.unsqueeze(1)  # (B, 1, N)
        # Match the number of channels the model expects (usually 2)
        if audio_tensor.shape[1] != model.audio_channels:
            logger.warning(f"Model expects {model.audio_channels} channels, input has {audio_tensor.shape[1]}. Repeating mono channel.")
            audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
        with torch.no_grad():
            # demucs.apply.apply_model handles chunking and overlap, and also works
            # for BagOfModels checkpoints where a direct model(...) call would not;
            # output shape: (batch, stems, channels, samples) -> [0] drops the batch dim
            sources = demucs.apply.apply_model(
                model, audio_tensor, device=model_device, shifts=1, split=True, overlap=0.25
            )[0]
        logger.debug(f"Raw separated sources tensor shape: {sources.shape}")  # (num_stems, channels, samples)
        # Map stems based on the model's sources list
        # Default for htdemucs: drums, bass, other, vocals
        stem_map = {name: sources[i] for i, name in enumerate(model.sources)}
        # Convert back to mono for simplicity (average channels) and move to the CPU
        output_stems = {}
        for name, data in stem_map.items():
            output_stems[name] = data.mean(dim=0).detach().cpu()
        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
# --- Model Loading Function ---
def load_hf_models():
    """Loads AI models at startup using the correct libraries."""
    global enhancement_models, separation_models
    if torch is None or speechbrain is None or demucs is None:
        logger.error("Core AI libraries (torch, speechbrain, demucs) not available. Skipping model loading.")
        return
    # --- Load Enhancement Model (SpeechBrain) ---
    enhancement_model_source = "speechbrain/sepformer-whamr-enhancement"
    try:
        logger.info(f"Loading enhancement model: {enhancement_model_source} (using SpeechBrain)...")
        # Optionally point SpeechBrain at a writable download location:
        # savedir = os.path.join(TEMP_DIR, "speechbrain_models")
        # os.makedirs(savedir, exist_ok=True)
        # Sepformer enhancement checkpoints are loaded via the SepformerSeparation interface
        enhancer = speechbrain.pretrained.SepformerSeparation.from_hparams(
            source=enhancement_model_source,
            # savedir=savedir,  # specify a download dir if needed
            run_opts={"device": DEVICE},  # pass the device option
        )
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer  # store with a consistent key
        logger.info(f"Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {DEVICE}.")
    except Exception as e:
        logger.error(f"Failed to load enhancement model '{enhancement_model_source}': {e}", exc_info=True)
    # --- Load Separation Model (Demucs) ---
    # Using a standard pretrained model name from the demucs package
    separation_model_name = SEPARATION_MODEL_KEY  # e.g., "htdemucs" or "mdx_extra_q"
    try:
        logger.info(f"Loading separation model: {separation_model_name} (using the demucs package)...")
        # demucs.pretrained.get_model downloads the checkpoint automatically
        separator = demucs.pretrained.get_model(name=separation_model_name)
        separator.to(DEVICE)
        separator.eval()
        separation_models[SEPARATION_MODEL_KEY] = separator  # store with a consistent key
        logger.info(f"Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {DEVICE}.")
    except Exception as e:
        logger.error(f"Failed to load separation model '{separation_model_name}': {e}", exc_info=True)
        logger.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
# --- FastAPI App ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries (torch, speechbrain, demucs).",
    version="2.1.0",
)

@app.on_event("startup")
async def startup_event():
    logger.info("Application startup: loading AI models...")
    # Run the blocking model load in a worker thread
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished (check logs for success/failure).")
# --- API Endpoints ---
@app.get("/")
def read_root():
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    if enhancement_models:
        ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
    if separation_models:
        ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY})")
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None available (check startup logs)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup.",
    }
# --- Basic Editing Endpoints ---
# (Add the /trim, /concat, /volume, /convert endpoints here -- same logic as before.
# Make sure they use the updated cleanup_file and save_upload_file helpers.
# A sketch of one such endpoint follows below.)
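# As a concrete reference, here is a minimal sketch of one basic editing
# endpoint: a /volume route that applies a gain change with pydub. The route
# shape and the `gain_db` parameter name are illustrative assumptions, not
# the original implementation.
@app.post("/volume")
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust."),
    gain_db: float = Form(..., description="Gain to apply in dB (negative to attenuate)."),
    output_format: str = Form("wav", description="Output format."),
):
    """Adjusts the overall volume of an uploaded file (sketch)."""
    input_path = await save_upload_file(file, prefix="volume_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = os.path.join(TEMP_DIR, f"volume_out_{uuid.uuid4().hex}.{output_format}")
    try:
        segment = AudioSegment.from_file(input_path)
        adjusted = segment + gain_db  # pydub interprets numeric addition as a dB gain
        adjusted.export(output_path, format=output_format)
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"volume_{os.path.splitext(file.filename)[0]}.{output_format}",
        )
    except CouldntDecodeError:
        raise HTTPException(status_code=415, detail="Could not decode the uploaded audio file.")
    except Exception as e:
        logger.error(f"Error during volume operation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")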
# --- AI Endpoints ---
@app.post("/enhance")
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    # The model key matters little while only one model is loaded, but keep it for flexibility
    model_key: str = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use."),
    output_format: str = Form("wav", description="Output format (wav or flac recommended)."),
):
    """Enhances speech audio using a pre-loaded SpeechBrain model."""
    if torch is None or speechbrain is None:
        raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
    if model_key not in enhancement_models:
        logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
    loaded_model = enhancement_models[model_key]
    logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        # Load audio as a tensor at the sampling rate the model expects
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Enhancement task completed.")
        # Save the result at the model's sampling rate
        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}",
        )
    except Exception as e:
        logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
@app.post("/separate")
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav or flac recommended)."),
):
    """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
    if torch is None or demucs is None:
        raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
    if model_key not in separation_models:
        logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
    loaded_model = separation_models[model_key]
    valid_stems = set(loaded_model.sources)  # get the stem names directly from the loaded model
    requested_stems = set(s.lower() for s in stems)
    if not requested_stems.issubset(valid_stems):
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Model '{model_key}' provides: {', '.join(valid_stems)}")
    logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)
    stem_output_paths: Dict[str, str] = {}
    zip_buffer = None
    try:
        # Load audio as a tensor at the Demucs sampling rate (44.1 kHz by default)
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Separation task completed.")
        # --- Create the ZIP file in memory ---
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Save only the requested stems
            for stem_name in requested_stems:
                if stem_name in all_separated_stems_tensors:
                    stem_tensor = all_separated_stems_tensors[stem_name]
                    # Save the stem temporarily at the model's native sampling rate
                    stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
                    stem_output_paths[stem_name] = stem_path
                    background_tasks.add_task(cleanup_file, stem_path)
                    archive_name = f"{stem_name}_{os.path.splitext(file.filename)[0]}.{output_format}"
                    zipf.write(stem_path, arcname=archive_name)
                    logger.info(f"Added '{archive_name}' to ZIP.")
                else:
                    # This case should be prevented by the earlier validation
                    logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen).")
        zip_buffer.seek(0)
        zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'},
        )
    except Exception as e:
        logger.error(f"Error during separation operation: {e}", exc_info=True)
        for path in stem_output_paths.values():
            cleanup_file(path)
        if zip_buffer:
            zip_buffer.close()
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
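# For reference, example requests against a locally running instance. The host,
# port, and file names are illustrative assumptions; repeated "stems" fields map
# onto the List[str] form parameter:
#
#   curl -F "file=@noisy_speech.wav" -F "output_format=wav" \
#        http://localhost:8000/enhance -o enhanced.wav
#
#   curl -F "file=@song.mp3" -F "stems=vocals" -F "stems=drums" \
#        http://localhost:8000/separate -o stems.zip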
# --- Reminder: add the remaining basic editing endpoints (/trim, /concat, /volume, /convert) above ---

# --- How to Run ---
# 1. Ensure FFmpeg is installed.
# 2. Save this code as `app.py`; create/update `requirements.txt`.
# 3. Install dependencies: `pip install -r requirements.txt` (this may take significant time/space!)
# 4. Run: `uvicorn app:app --reload --host 0.0.0.0`
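# A plausible `requirements.txt` for this app, derived from the imports above
# (versions left unpinned here; pin whatever your environment resolves):
#
#   fastapi
#   uvicorn[standard]
#   python-multipart   # needed by FastAPI for Form/File uploads
#   pydub
#   soundfile
#   numpy
#   librosa
#   torch
#   torchaudio
#   speechbrain
#   demucs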