Spaces:
Sleeping
Sleeping
# ----------- START app.py ----------- | |
import os | |
import uuid | |
import tempfile | |
import logging | |
import asyncio | |
from typing import List, Optional, Dict, Any | |
import io | |
import zipfile | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query | |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
# --- Basic Editing Imports --- | |
from pydub import AudioSegment | |
from pydub.exceptions import CouldntDecodeError | |
# --- AI & Advanced Audio Imports --- | |
# Optional AI stack. All of these imports must succeed for the AI endpoints to
# be usable; otherwise AI_LIBRARIES_AVAILABLE is cleared and every name the
# block would have imported is bound to None, so later code can test the flag
# (or `is None`) instead of hitting a NameError.
try:
    import torch
    from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor  # Using pipeline for simplicity where possible
    # Specific model imports might be needed depending on the chosen approach
    # E.g. for Demucs V4 (Hybrid Transformer): from demucs.hdemucs import HDemucs
    # from demucs.pretrained import hdemucs_mmi
    import soundfile as sf
    import numpy as np
    import librosa  # For resampling if needed
    AI_LIBRARIES_AVAILABLE = True
    print("AI and advanced audio libraries loaded.")
except (ImportError, OSError) as e:
    # OSError is included because soundfile raises it (not ImportError) when
    # the native libsndfile library cannot be located.
    print(f"Warning: Error importing AI/Audio libraries: {e}")
    print("Ensure torch, transformers, soundfile, librosa are installed.")
    print("AI features will be unavailable.")
    AI_LIBRARIES_AVAILABLE = False
    # Placeholders for every name imported above.
    torch = None
    pipeline = None
    AutoModelForSpeechSeq2Seq = None
    AutoProcessor = None
    sf = None
    np = None
    librosa = None
# --- Configuration & Setup --- | |
# Directory for all temporary upload/output files (the system temp directory).
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)

# Configure logging once for the whole process; `logger` is the module logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Global registries for loaded models ---
# Keyed by an internal model key; the value is whatever object the loader
# stored (model, pipeline, ...) or None when loading failed.
enhancement_models: Dict[str, Any] = {}  # Store model/processor or pipeline
separation_models: Dict[str, Any] = {}  # Store model/processor or pipeline

# Target sampling rates must match what the chosen models expect
# (check the model cards on Hugging Face) and the IDs used by any
# build-time download script.
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
ENHANCEMENT_SR = 16000  # Rate the enhancement input is resampled to

# NOTE(review): the separation model ID below is an example value; confirm
# that it exists on the Hub and is loadable by the loader used in
# load_hf_models before relying on it.
SEPARATION_MODEL_ID = "facebook/demucs_quantized"  # Example using a quantized version (smaller, faster CPU)
# SEPARATION_MODEL_ID = "facebook/hdemucs_mmi"  # Alternative ID (would need the demucs library)
DEMUCS_SR = 44100  # Rate the separation input is resampled to

# Hugging Face cache directory: HF_HOME from the environment, else a default
# path inside the container.
HF_CACHE_DIR = os.environ.get("HF_HOME", "/app/hf_cache")  # Use HF_HOME from Dockerfile or default
# --- Helper Functions ---
# cleanup_file, save_upload_file, load_audio_for_hf and save_hf_audio handle
# temporary files and audio I/O for the endpoints below.
def cleanup_file(file_path: str):
    """Delete *file_path* if it exists; never raise.

    A falsy path or a path that does not exist is silently ignored. Any
    error raised while checking or deleting is logged (without traceback)
    and swallowed, so this is safe to use as a best-effort cleanup hook.
    """
    try:
        # Guard clause: nothing to do for empty or already-missing paths.
        if not file_path or not os.path.exists(file_path):
            return
        os.remove(file_path)
        logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as exc:
        logger.error(f"Error cleaning up file {file_path}: {exc}", exc_info=False)
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Stream an uploaded file to a uniquely named file in TEMP_DIR.

    Args:
        upload_file: The incoming upload. It is always closed before returning.
        prefix: Prefix for the generated temporary file name.

    Returns:
        The path of the temporary file. The caller is responsible for
        deleting it.

    Raises:
        HTTPException: 500 if the file could not be written.
    """
    # The filename is optional in a multipart upload; treat a missing name
    # as "no extension" instead of crashing in os.path.splitext(None).
    original_name = upload_file.filename or ""
    _, file_extension = os.path.splitext(original_name)
    if not file_extension:
        file_extension = ".wav"  # Default if no extension
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
            # Copy in 1 MiB chunks so large uploads are never held in memory.
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)  # Remove any partially written file
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}") from e
    finally:
        await upload_file.close()
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[np.ndarray, int]":
    """Load an audio file as mono float32 samples, optionally resampled.

    Args:
        file_path: Path of the audio file to read with soundfile.
        target_sr: If given and different from the file's rate, the audio is
            resampled to this rate with librosa.

    Returns:
        ``(samples, sampling_rate)`` where ``samples`` is a 1-D float32 array
        and ``sampling_rate`` is ``target_sr`` when resampling happened,
        otherwise the file's own rate.

    Raises:
        HTTPException: 501 if the audio libraries are unavailable, 415 if the
            file cannot be read or processed.
    """
    # NOTE: the return annotation is a string on purpose. `np` is None when the
    # optional imports failed, and an unquoted `np.ndarray` would then raise
    # AttributeError while this module is being imported.
    if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
        raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
        if audio.ndim > 1:
            # soundfile returns multi-channel data as (frames, channels), so the
            # channel axis is always axis 1. Guessing the layout from the shape
            # gives wrong results for clips with fewer frames than channels.
            audio = np.mean(audio, axis=1)
            logger.info(f"Converted audio to mono, new shape: {audio.shape}")
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                # A missing dependency is a server problem, not a bad upload.
                raise HTTPException(status_code=501, detail="Librosa is required for resampling but not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            # ascontiguousarray is a no-op when the array is already contiguous.
            audio = librosa.resample(y=np.ascontiguousarray(audio), orig_sr=orig_sr, target_sr=target_sr)
            logger.info(f"Resampled audio shape: {audio.shape}")
            current_sr = target_sr
        else:
            current_sr = orig_sr
        return audio, current_sr
    except HTTPException:
        # Already carries the right status code; do not rewrap as 415.
        raise
    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.") from e
def save_hf_audio(audio_data: "np.ndarray", sampling_rate: int, output_format: str = "wav") -> str:
    """Write a NumPy audio array to a new temporary file and return its path.

    ``wav`` and ``flac`` are written with soundfile; any other format is
    exported through pydub as 16-bit mono.

    Args:
        audio_data: Audio samples. Values are converted to float32 and
            clipped to [-1, 1] before writing.
        sampling_rate: Sampling rate of ``audio_data`` in Hz.
        output_format: File format / extension, letters and digits only.

    Returns:
        Path of the written file inside TEMP_DIR. The caller must delete it.

    Raises:
        HTTPException: 501 if the audio libraries are unavailable, 422 if
            ``output_format`` is not a plain alphanumeric token, 500 if
            writing fails.
    """
    # NOTE: the parameter annotation is a string on purpose. `np` is None when
    # the optional imports failed, and an unquoted `np.ndarray` would then raise
    # AttributeError while this module is being imported.
    if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
        raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
    # The format string becomes part of a filesystem path and an encoder
    # argument, and it originates from request data: reject anything that
    # could contain path separators or dots.
    if not output_format or not output_format.isalnum():
        raise HTTPException(status_code=422, detail="Invalid output format.")
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format}, shape={audio_data.shape})")
        # Ensure data is float32 for common formats like wav/flac
        if audio_data.dtype != np.float32:
            logger.warning(f"Audio data has dtype {audio_data.dtype}, converting to float32.")
            audio_data = audio_data.astype(np.float32)
        # Clip to [-1, 1] so out-of-range samples do not wrap or distort.
        audio_data = np.clip(audio_data, -1.0, 1.0)
        if output_format.lower() in ('wav', 'flac'):
            # Lossless formats are written directly by soundfile.
            sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
        else:
            # Lossy formats (e.g. mp3) go through pydub, which needs int16 PCM.
            logger.debug("Using pydub for lossy format export...")
            if audio_data.ndim > 1:  # Should be mono by now, but double check
                logger.warning("Audio data still has multiple dimensions before pydub export, attempting mean.")
                # Downmix in float, before quantising, to avoid extra rounding.
                audio_data = np.mean(audio_data, axis=1)
            audio_int16 = (audio_data * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1,  # Assuming mono output from AI models for now
            )
            segment.export(output_path, format=output_format)
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)  # Do not leave a partial file behind
        raise HTTPException(status_code=500, detail="Failed to save processed audio.") from e
# --- Synchronous AI Inference Functions ---
# _run_enhancement_sync and _run_separation_sync are blocking; the endpoints
# run them via asyncio.to_thread. They raise if the requested model is missing.
def _run_enhancement_sync(model_key: str, audio_data: "np.ndarray", sampling_rate: int) -> "np.ndarray":
    """Run the enhancement model registered under *model_key* (blocking).

    Args:
        model_key: Key into the ``enhancement_models`` registry.
        audio_data: Input samples passed to the model as ``"raw"``.
        sampling_rate: Sampling rate of ``audio_data`` in Hz.

    Returns:
        The enhanced audio array taken from the model result.

    Raises:
        ValueError: If the AI libraries are missing, the key is unknown, or
            the registered model is None.
        Exception: Any error raised by the model is logged and re-raised.
    """
    # NOTE: annotations are strings on purpose. `np` is None when the optional
    # imports failed, and an unquoted `np.ndarray` would then raise
    # AttributeError while this module is being imported.
    if not AI_LIBRARIES_AVAILABLE or model_key not in enhancement_models:
        raise ValueError(f"Enhancement model '{model_key}' not available or AI libraries missing.")
    # This code assumes the registry stores a pipeline-like callable.
    enhancer = enhancement_models[model_key]
    # Explicit None test: a loaded object must not be rejected just because
    # it happens to be falsy (e.g. defines __len__ returning 0).
    if enhancer is None:
        raise ValueError(f"Enhancement pipeline '{model_key}' is None.")
    try:
        logger.info(f"Running speech enhancement with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # NOTE(review): the call convention and result layout below assume a
        # generic HF-pipeline-style interface; a SpeechBrain model loaded
        # directly would need its own call (e.g. an enhance_batch method).
        # Confirm against whatever load_hf_models actually stores.
        result = enhancer({"raw": audio_data, "sampling_rate": sampling_rate})
        enhanced_audio = result["audio"]["array"]  # Adjust based on actual pipeline output
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference with '{model_key}': {e}", exc_info=True)
        raise  # Re-raise to be caught by the async wrapper
def _run_separation_sync(model_key: str, audio_data: "np.ndarray", sampling_rate: int) -> "Dict[str, np.ndarray]":
    """Run the separation model registered under *model_key* (blocking).

    Args:
        model_key: Key into the ``separation_models`` registry.
        audio_data: Mono ``(samples,)`` or stereo ``(2, samples)`` /
            ``(samples, 2)`` audio. Mono input is duplicated to stereo.
        sampling_rate: Sampling rate of ``audio_data`` in Hz (logged only).

    Returns:
        Mapping of stem name to a mono array (channels averaged).

    Raises:
        ValueError: If the AI libraries are missing, the key is unknown, the
            registered model is None, or the input shape is unsupported.
        Exception: Any error raised by the model is logged and re-raised.
    """
    # NOTE: annotations are strings on purpose. `np` is None when the optional
    # imports failed, and an unquoted `np.ndarray` would then raise
    # AttributeError while this module is being imported.
    if not AI_LIBRARIES_AVAILABLE or model_key not in separation_models:
        raise ValueError(f"Separation model '{model_key}' not available or AI libraries missing.")
    model = separation_models[model_key]  # Assuming direct model object is stored for Demucs
    # Explicit None test: do not rely on the truthiness of a model object.
    if model is None:
        raise ValueError(f"Separation model '{model_key}' is None.")
    try:
        logger.info(f"Running source separation with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Build a (2, samples) stereo array for the model.
        if audio_data.ndim == 1:
            logger.debug("Separation input is mono, duplicating to create stereo.")
            audio_data = np.stack([audio_data, audio_data], axis=0)  # (2, samples)
        elif audio_data.ndim != 2:
            # Reject other ranks explicitly, before shape[1] is indexed below.
            raise ValueError(f"Unexpected input audio shape for separation: {audio_data.shape}")
        if audio_data.shape[0] != 2:
            if audio_data.shape[1] == 2:
                audio_data = audio_data.T  # (samples, 2) -> (2, samples)
            else:
                raise ValueError(f"Unexpected input audio shape for separation: {audio_data.shape}")
        audio_tensor = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0)  # (1, 2, samples)
        # Run on whichever device the model's parameters live on.
        device = next(model.parameters()).device
        logger.debug(f"Moving separation tensor to device: {device}")
        audio_tensor = audio_tensor.to(device)
        with torch.no_grad():
            logger.debug("Starting model inference for separation...")
            # NOTE(review): assumes the model is directly callable and returns
            # (batch, stems, channels, samples) — confirm for the loaded model.
            sources = model(audio_tensor)[0]  # Remove batch dim
            logger.debug(f"Model inference complete, sources shape: {sources.shape}")
        sources_np = sources.detach().cpu().numpy()  # (stems, channels, samples)
        # NOTE(review): this stem order is assumed for the configured model;
        # verify it against the model actually loaded.
        stem_names = ['drums', 'bass', 'other', 'vocals']
        if sources_np.shape[0] != len(stem_names):
            logger.warning(f"Model output {sources_np.shape[0]} stems, expected {len(stem_names)}. Stem names might be incorrect.")
            # Fall back to generic names when the stem count is unexpected.
            stem_names = [f"stem_{i+1}" for i in range(sources_np.shape[0])]
        stems = {}
        for i, name in enumerate(stem_names):
            # Average channels to get a mono stem.
            mono_stem = np.mean(sources_np[i], axis=0)
            stems[name] = mono_stem
            logger.debug(f"Extracted stem '{name}', shape: {mono_stem.shape}")
        logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
        return stems
    except Exception as e:
        logger.error(f"Error during synchronous separation inference with '{model_key}': {e}", exc_info=True)
        raise
# --- Model Loading Function ---
# load_hf_models fills the model registries at startup. Its model IDs must
# match the constants above; the loading logic needs adjusting if a library
# such as `demucs` is used directly.
def load_hf_models():
    """Populate the model registries; intended to run once at startup.

    Does nothing (beyond a warning) when the AI libraries are unavailable.
    Failures are logged, never raised. A separation model that fails to
    load is registered as None.
    """
    if not AI_LIBRARIES_AVAILABLE:
        logger.warning("Skipping Hugging Face model loading as libraries are missing.")
        return

    # ---------------- Enhancement model ----------------
    # Nothing is loaded here yet: this section only logs. A real
    # implementation would store either an HF pipeline or a SpeechBrain
    # object (e.g. loaded via its from_hparams constructor) under
    # `enhancement_key` in `enhancement_models`.
    enhancement_key = "speechbrain_enhancer"  # Internal key (not registered yet)
    try:
        logger.info(f"Attempting to load enhancement model: {ENHANCEMENT_MODEL_ID}...")
        logger.warning(f"Actual loading for {ENHANCEMENT_MODEL_ID} skipped - requires SpeechBrain toolkit or specific pipeline setup.")
    except Exception as enh_err:
        logger.error(f"Failed to load enhancement model '{ENHANCEMENT_MODEL_ID}': {enh_err}", exc_info=False)

    # ---------------- Separation model ----------------
    # Attempted through a transformers Auto class. If that fails, the entry
    # is set to None; loading through the `demucs` library is the suggested
    # alternative but is not implemented here.
    separation_key = "demucs_separator"  # Internal key
    try:
        logger.info(f"Attempting to load separation model: {SEPARATION_MODEL_ID}...")
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Loading Demucs on device: {device}")
            # NOTE(review): AutoModelForSpeechSeq2Seq may not be the right
            # Auto class for this model; a different one might be needed.
            loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                SEPARATION_MODEL_ID,
                cache_dir=HF_CACHE_DIR,
            ).to(device)
            separation_models[separation_key] = loaded_model
            if hasattr(loaded_model, 'eval'):
                loaded_model.eval()  # Switch to evaluation mode
            logger.info(f"Successfully loaded separation model '{SEPARATION_MODEL_ID}' using AutoModel.")
        except Exception as auto_model_err:
            logger.warning(f"Failed to load '{SEPARATION_MODEL_ID}' using AutoModel: {auto_model_err}. Consider installing 'demucs' library.")
            separation_models[separation_key] = None  # Mark as unavailable
    except Exception as sep_err:
        logger.error(f"General error loading separation model '{SEPARATION_MODEL_ID}': {sep_err}", exc_info=False)
        if separation_key in separation_models:
            separation_models[separation_key] = None
# --- FastAPI App and Endpoints --- | |
# The ASGI application object.
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and HF model dependencies.",
    version="2.0.0",
)
# NOTE(review): no startup-hook decorator/registration is visible on this
# function in this view — confirm it is actually wired to application startup.
async def startup_event():
    """Load the AI models in a worker thread so the event loop is not blocked."""
    logger.info("Application startup: Loading AI models (this may take time)...")
    loading_job = asyncio.to_thread(load_hf_models)
    await loading_job
    logger.info("Model loading process finished.")
# --- API Endpoints ---
# Basic editing endpoints: /, /trim, /concat, /volume, /convert.
# NOTE(review): no route decorator is visible on this function in this view —
# confirm it is registered as the root endpoint.
def read_root():
    """Return a welcome payload listing the basic and AI features.

    An AI feature is listed only when its registry holds at least one entry
    that is not None.
    """
    def _has_loaded_model(registry):
        return any(entry is not None for entry in registry.values())

    ai_candidates = (
        ("/enhance", enhancement_models),
        ("/separate", separation_models),
    )
    ai_features = [route for route, registry in ai_candidates if _has_loaded_model(registry)]
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": ["/trim", "/concat", "/volume", "/convert"],
        "ai_features": ai_features or "None loaded (check logs)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs)."
    }
# NOTE(review): no route decorator is visible on this function in this view —
# presumably registered as the /trim endpoint; confirm.
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., description="Start time in milliseconds."),
    end_ms: int = Form(..., description="End time in milliseconds.")
):
    """Trim an uploaded audio file to ``[start_ms, end_ms)`` (milliseconds).

    The result keeps the upload's extension as its format (``mp3`` when the
    extension is missing or ``tmp``).

    Returns:
        A FileResponse with the trimmed audio. Temporary files are deleted
        by background tasks after the response is sent.

    Raises:
        HTTPException: 422 for an invalid time range, 415 if the file cannot
            be decoded, 500 for any other failure.
    """
    if start_ms < 0 or end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")
    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    # The upload's filename is optional; fall back to a neutral name.
    source_name = file.filename or "audio"
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="trim_in_")
        background_tasks.add_task(cleanup_file, input_path)  # Schedule input cleanup
        # pydub slices by milliseconds.
        audio = AudioSegment.from_file(input_path)
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
        original_format = os.path.splitext(source_name)[1][1:].lower()
        if not original_format or original_format == "tmp":
            original_format = "mp3"
        output_filename = f"trimmed_{uuid.uuid4().hex}.{original_format}"
        output_path = os.path.join(TEMP_DIR, output_filename)
        trimmed_audio.export(output_path, format=original_format)
        background_tasks.add_task(cleanup_file, output_path)  # Schedule output cleanup
        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",  # Attempt correct media type
            filename=f"trimmed_{source_name}"
        )
    except CouldntDecodeError:
        logger.warning(f"pydub failed to decode: {file.filename}")
        # Scheduled background tasks only run after a successful response,
        # so the temp files must be removed here or they leak.
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
    except Exception as e:
        logger.error(f"Error during trim operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)  # Immediate cleanup on error
        if input_path: cleanup_file(input_path)  # Immediate cleanup on error
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}") from e
# NOTE(review): no route decorator is visible on this function in this view —
# presumably registered as the /concat endpoint; confirm.
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
):
    """Join two or more uploaded audio files, in upload order, into one file.

    Returns:
        A FileResponse with the combined audio in ``output_format``.
        Temporary files are deleted by background tasks after the response.

    Raises:
        HTTPException: 422 if fewer than two files are given or the output
            format is not a plain alphanumeric token, 415 if an input cannot
            be decoded, 500 for any other failure.
    """
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
    # The format becomes part of a filesystem path and an encoder argument;
    # accept only letters/digits so it cannot carry path separators or dots.
    if not output_format or not output_format.isalnum():
        raise HTTPException(status_code=422, detail="Invalid output format.")
    output_format = output_format.lower()
    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
    input_paths = []
    output_path = None
    try:
        combined_audio = AudioSegment.empty()
        first_filename_base = "combined"
        for i, file in enumerate(files):
            input_path = await save_upload_file(file, prefix=f"concat_{i}_")
            input_paths.append(input_path)
            background_tasks.add_task(cleanup_file, input_path)
            audio = AudioSegment.from_file(input_path)
            combined_audio += audio
            if i == 0 and file.filename:
                # The download name is derived from the first upload's name.
                first_filename_base = os.path.splitext(file.filename)[0]
            logger.info(f"Appended '{file.filename}', current total duration: {len(combined_audio)}ms")
        if len(combined_audio) == 0:
            raise HTTPException(status_code=500, detail="No audio data after attempting concatenation.")
        output_filename_final = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
        output_path = os.path.join(TEMP_DIR, f"concat_out_{uuid.uuid4().hex}.{output_format}")
        combined_audio.export(output_path, format=output_format)
        background_tasks.add_task(cleanup_file, output_path)  # Schedule output cleanup
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename_final
        )
    except CouldntDecodeError as e:
        logger.warning(f"pydub failed to decode one of the concat files: {e}")
        # Scheduled background tasks only run after a successful response,
        # so the temp files must be removed here or they leak.
        if output_path: cleanup_file(output_path)
        for p in input_paths: cleanup_file(p)
        raise HTTPException(status_code=415, detail=f"Unsupported format or corrupted file among inputs: {e}")
    except Exception as e:
        logger.error(f"Error during concat operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        for p in input_paths: cleanup_file(p)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}") from e
# NOTE(review): no route decorator is visible on this function in this view —
# presumably registered as the /volume endpoint; confirm.
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). Positive values increase volume, negative values decrease.")
):
    """Apply a gain of ``change_db`` decibels to an uploaded audio file.

    The result keeps the upload's extension as its format (``mp3`` when the
    extension is missing or ``tmp``).

    Returns:
        A FileResponse with the adjusted audio. Temporary files are deleted
        by background tasks after the response is sent.

    Raises:
        HTTPException: 415 if the file cannot be decoded, 500 for any other
            failure.
    """
    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    # The upload's filename is optional; fall back to a neutral name.
    source_name = file.filename or "audio"
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="volume_in_")
        background_tasks.add_task(cleanup_file, input_path)
        audio = AudioSegment.from_file(input_path)
        adjusted_audio = audio + change_db  # pydub: adding a number applies gain in dB
        logger.info(f"Volume adjusted by {change_db}dB.")
        original_format = os.path.splitext(source_name)[1][1:].lower()
        if not original_format or original_format == "tmp":
            original_format = "mp3"
        output_filename_final = f"volume_{change_db}dB_{source_name}"
        output_path = os.path.join(TEMP_DIR, f"volume_out_{uuid.uuid4().hex}.{original_format}")
        adjusted_audio.export(output_path, format=original_format)
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",
            filename=output_filename_final
        )
    except CouldntDecodeError:
        logger.warning(f"pydub failed to decode: {file.filename}")
        # Scheduled background tasks only run after a successful response,
        # so the temp files must be removed here or they leak.
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
    except Exception as e:
        logger.error(f"Error during volume operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}") from e
# NOTE(review): no route decorator is visible on this function in this view —
# presumably registered as the /convert endpoint; confirm.
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
):
    """Convert an uploaded audio file to another format.

    Returns:
        A FileResponse with the converted audio. Temporary files are deleted
        by background tasks after the response is sent.

    Raises:
        HTTPException: 422 if ``output_format`` is not allowed, 415 if the
            file cannot be decoded, 500 for any other failure.
    """
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a'}
    # File extensions that are not FFmpeg muxer names, mapped to the muxer
    # that produces that kind of file.
    ffmpeg_muxers = {'m4a': 'ipod', 'aac': 'adts'}
    if output_format.lower() not in allowed_formats:
        # Sorted so the error message is deterministic.
        raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed: {', '.join(sorted(allowed_formats))}")
    logger.info(f"Convert request: file='{file.filename}', output_format='{output_format}'")
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="convert_in_")
        background_tasks.add_task(cleanup_file, input_path)
        audio = AudioSegment.from_file(input_path)
        output_format_lower = output_format.lower()
        # The upload's filename is optional; fall back to a neutral name.
        filename_base = os.path.splitext(file.filename or "audio")[0]
        output_filename_final = f"{filename_base}_converted.{output_format_lower}"
        output_path = os.path.join(TEMP_DIR, f"convert_out_{uuid.uuid4().hex}.{output_format_lower}")
        audio.export(output_path, format=ffmpeg_muxers.get(output_format_lower, output_format_lower))
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format_lower}",
            filename=output_filename_final
        )
    except CouldntDecodeError:
        logger.warning(f"pydub failed to decode: {file.filename}")
        # Scheduled background tasks only run after a successful response,
        # so the temp files must be removed here or they leak.
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
    except Exception as e:
        logger.error(f"Error during convert operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}") from e
# --- AI Endpoints: /enhance and /separate ---
# NOTE(review): no route decorator is visible on this function in this view —
# presumably registered as the /enhance endpoint; confirm.
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    model_key: str = Query("speechbrain_enhancer", description="Internal key of the enhancement model to use."),
    output_format: str = Form("wav", description="Output format for the enhanced audio (wav, flac recommended).")
):
    """Enhance speech audio with a pre-loaded AI model (experimental).

    The upload is converted to mono at ENHANCEMENT_SR, processed in a worker
    thread, and returned in ``output_format``.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 503 if the
            requested model is not loaded, 422 if the output format is not a
            plain alphanumeric token, 500 for any unexpected failure. Errors
            raised by the audio helpers are passed through.
    """
    if not AI_LIBRARIES_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if enhancement_models.get(model_key) is None:
        logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
    # The format ends up in a file path and the download name; accept only
    # letters/digits so it cannot carry path separators or dots.
    if not output_format or not output_format.isalnum():
        raise HTTPException(status_code=422, detail="Invalid output format.")
    logger.info(f"Enhance request: file='{file.filename}', model_key='{model_key}', format='{output_format}'")
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="enhance_in_")
        background_tasks.add_task(cleanup_file, input_path)
        # Load audio at the rate the model expects.
        logger.debug(f"Loading audio for enhancement, target SR: {ENHANCEMENT_SR}")
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
        if current_sr != ENHANCEMENT_SR:  # Should have been resampled, but double check
            logger.warning(f"Audio SR after loading is {current_sr}, expected {ENHANCEMENT_SR}. Check resampling.")
        # Inference is blocking; run it off the event loop.
        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio = await asyncio.to_thread(
            _run_enhancement_sync, model_key, audio_data, current_sr
        )
        logger.info("Enhancement task completed.")
        output_path = save_hf_audio(enhanced_audio, ENHANCEMENT_SR, output_format)  # Save with model's target SR
        background_tasks.add_task(cleanup_file, output_path)
        # The upload's filename is optional; fall back to a neutral name.
        output_filename_final = f"enhanced_{os.path.splitext(file.filename or 'audio')[0]}.{output_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename_final
        )
    except Exception as e:
        logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        # Background tasks do not run on error, so clean up immediately.
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}") from e
# NOTE(review): no route decorator is visible on this function in this view —
# presumably registered as the /separate endpoint; confirm.
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: str = Query("demucs_separator", description="Internal key of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separate music into stems with a pre-loaded AI model (experimental).

    The upload is converted to mono at DEMUCS_SR, separated in a worker
    thread, and the requested stems are returned as an in-memory ZIP archive.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 503 if the
            requested model is not loaded, 422 for unknown stems or an
            invalid output format, 404 if no requested stem was produced,
            500 for any unexpected failure.
    """
    from urllib.parse import quote  # Only needed for the download header below

    if not AI_LIBRARIES_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if separation_models.get(model_key) is None:
        logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
    valid_stems = {'vocals', 'drums', 'bass', 'other'}
    requested_stems = {s.lower() for s in stems}
    if not requested_stems.issubset(valid_stems):
        # Sorted so the error message is deterministic.
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are generally: {', '.join(sorted(valid_stems))}")
    # The format ends up in file paths and archive member names; accept only
    # letters/digits so it cannot carry path separators or dots.
    if not output_format or not output_format.isalnum():
        raise HTTPException(status_code=422, detail="Invalid output format.")
    logger.info(f"Separate request: file='{file.filename}', model_key='{model_key}', stems={requested_stems}, format='{output_format}'")
    input_path = None
    stem_output_paths: Dict[str, str] = {}
    zip_buffer = io.BytesIO()  # The ZIP is assembled in memory
    try:
        input_path = await save_upload_file(file, prefix="separate_in_")
        background_tasks.add_task(cleanup_file, input_path)  # Schedule input cleanup
        logger.debug(f"Loading audio for separation, target SR: {DEMUCS_SR}")
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
        if current_sr != DEMUCS_SR:
            logger.warning(f"Audio SR after loading is {current_sr}, expected {DEMUCS_SR}. Check resampling.")
        # Inference is blocking; run it off the event loop.
        logger.info("Submitting separation task to background thread...")
        all_separated_stems = await asyncio.to_thread(
            _run_separation_sync, model_key, audio_data, current_sr
        )
        logger.info("Separation task completed.")
        # The upload's filename is optional; fall back to a neutral name.
        zip_filename_base = f"separated_{os.path.splitext(file.filename or 'audio')[0]}"
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            logger.info("Creating ZIP archive in memory...")
            found_stems_count = 0
            for stem_name in requested_stems:
                if stem_name not in all_separated_stems:
                    logger.warning(f"Requested stem '{stem_name}' not found in model output keys: {list(all_separated_stems.keys())}")
                    continue
                stem_data = all_separated_stems[stem_name]
                if stem_data is None or stem_data.size == 0:
                    logger.warning(f"Stem '{stem_name}' data is empty, skipping.")
                    continue
                # The encoders write to disk, so each stem goes through a
                # temporary file before being added to the archive.
                logger.debug(f"Saving temporary stem file for '{stem_name}'...")
                stem_path = save_hf_audio(stem_data, DEMUCS_SR, output_format)  # Save with model's target SR
                stem_output_paths[stem_name] = stem_path  # Keep track for cleanup
                background_tasks.add_task(cleanup_file, stem_path)
                archive_name = f"{stem_name}.{output_format}"  # Simple name inside zip
                zipf.write(stem_path, arcname=archive_name)
                logger.info(f"Added '{archive_name}' to ZIP.")
                found_stems_count += 1
            if found_stems_count == 0:
                raise HTTPException(status_code=404, detail="None of the requested stems were found or generated successfully.")
        zip_buffer.seek(0)  # Rewind so the response streams from the start
        zip_filename_download = f"{zip_filename_base}.zip"
        logger.info(f"Sending ZIP file '{zip_filename_download}'")
        # The name comes from the client's filename. Quotes would break the
        # header and non-latin-1 characters cannot be encoded in it, so such
        # names are sent percent-encoded using the RFC 5987 `filename*` form.
        quoted_name = quote(zip_filename_download)
        if quoted_name != zip_filename_download:
            content_disposition = f"attachment; filename*=utf-8''{quoted_name}"
        else:
            content_disposition = f'attachment; filename="{zip_filename_download}"'
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={'Content-Disposition': content_disposition}
        )
    except Exception as e:
        logger.error(f"Error during separation operation: {e}", exc_info=True)
        # Background tasks do not run on error, so clean up immediately.
        for path in stem_output_paths.values(): cleanup_file(path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}") from e
# ----------- END app.py ----------- |