Create app.py
app.py
ADDED
@@ -0,0 +1,454 @@
from __future__ import annotations  # Keep type hints (e.g., np.ndarray) unevaluated so they don't crash if numpy is missing

import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
import io  # For building the ZIP archive in memory
import zipfile

# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

# --- AI & Advanced Audio Imports ---
try:
    import torch
    from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor  # Using pipeline for simplicity where possible
    # Specific model imports might be needed depending on the chosen approach
    import soundfile as sf
    import numpy as np
    import librosa  # For resampling if needed
    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Error importing AI/Audio libraries: {e}")
    print("Ensure torch, transformers, soundfile, and librosa are installed.")
    print("AI features will be unavailable.")
    torch = None
    pipeline = None
    sf = None
    np = None
    librosa = None

# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Use dictionaries to potentially hold multiple models of each type later
enhancement_pipelines: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}  # Might store a pipeline or a model/processor pair

# Target sampling rates for the models (check the model cards on Hugging Face!)
ENHANCEMENT_SR = 16000  # Example for speechbrain/sepformer
DEMUCS_SR = 44100       # Demucs default

# --- Helper Functions ---

def cleanup_file(file_path: str):
    """Safely remove a file."""
    try:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)

async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Saves an uploaded file to a temporary location and returns the path."""
    # Generate a unique temporary file path
    _, file_extension = os.path.splitext(upload_file.filename)
    if not file_extension:
        file_extension = ".wav"  # Default if no extension
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        await upload_file.close()

# --- Audio Loading/Saving for AI Models ---

def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[np.ndarray, int]:
    """Loads audio using soundfile, converts to mono float32, and optionally resamples."""
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")

        # Convert to mono if stereo (simple channel averaging)
        if audio.ndim > 1 and audio.shape[1] > 1:
            audio = np.mean(audio, axis=1)
            logger.info("Converted audio to mono")

        # Resample if necessary
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("librosa is required for resampling but is not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
            logger.info(f"Resampled audio shape: {audio.shape}")
            current_sr = target_sr
        else:
            current_sr = orig_sr

        return audio, current_sr

    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")

def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves a NumPy audio array to a temporary file."""
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI-processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
        # Ensure data is float32 for lossless formats like wav/flac; pydub handles mp3 etc.
        if audio_data.dtype != np.float32:
            audio_data = audio_data.astype(np.float32)

        # Use soundfile for lossless formats
        if output_format.lower() in ['wav', 'flac']:
            sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
        else:
            # For lossy formats like mp3, convert the float32 [-1.0, 1.0] array to a
            # 16-bit integer pydub segment, clipping first to avoid integer overflow.
            audio_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1  # Assuming mono output from the AI models for now
            )
            segment.export(output_path, format=output_format)

        return output_path
    except Exception as e:
        logger.error(f"Error saving AI-processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")

# --- Synchronous AI Inference Functions (to be run in threads) ---

def _run_enhancement_sync(model_pipeline: Any, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Synchronous wrapper for enhancement model inference."""
    if not model_pipeline:
        raise ValueError("Enhancement model not loaded")
    try:
        logger.info(f"Running speech enhancement (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Pipeline usage depends heavily on the specific pipeline.
        # Example for a hypothetical 'audio-enhancement' pipeline:
        result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
        enhanced_audio = result["audio"]["array"]  # Adjust based on the actual pipeline output
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise  # Re-raise to be caught by the async wrapper

def _run_separation_sync(model: Any, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
    """Synchronous wrapper for source separation model inference."""
    if not model:
        raise ValueError("Separation model not loaded")
    try:
        logger.info(f"Running source separation (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Usage depends on the separation model; this follows the raw Demucs calling
        # convention rather than a transformers pipeline.
        # Demucs expects input of shape (batch, channels, samples), so convert mono
        # to stereo if needed and add a batch dimension.
        if audio_data.ndim == 1:
            audio_data = np.stack([audio_data, audio_data], axis=0)  # Duplicate mono channel to stereo
        audio_tensor = torch.tensor(audio_data).unsqueeze(0)  # Add batch dimension

        # Move the input to the same device as the model (CPU or GPU)
        device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(device)

        with torch.no_grad():
            sources = model(audio_tensor)[0]  # Output shape: (stems, channels, samples)

        # Detach, move to CPU, convert to numpy
        sources_np = sources.detach().cpu().numpy()

        # Convert back to mono for simplicity (average channels).
        # Important: the stem order (drums, bass, other, vocals) is specific to the
        # default Demucs v3/v4 models.
        stems = {
            'drums': np.mean(sources_np[0], axis=0),
            'bass': np.mean(sources_np[1], axis=0),
            'other': np.mean(sources_np[2], axis=0),
            'vocals': np.mean(sources_np[3], axis=0),
        }
        logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
        return stems

    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise

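# A hedged alternative to the hypothetical pipeline call above: if no transformers
# pipeline exists for the chosen enhancement model, SpeechBrain's own inference
# classes can be used instead. This sketch assumes the 'speechbrain' package is
# installed and that the model ID and separate_batch() usage match the current
# SpeechBrain API; treat it as illustrative, not definitive.
#
# from speechbrain.pretrained import SepformerSeparation
#
# def _run_enhancement_speechbrain(audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
#     model = SepformerSeparation.from_hparams(
#         source="speechbrain/sepformer-wham-enhancement",  # example model ID
#         savedir="pretrained_models/sepformer-wham-enhancement",
#     )
#     mix = torch.tensor(audio_data).unsqueeze(0)   # (batch, samples)
#     est = model.separate_batch(mix)               # (batch, samples, n_sources)
#     return est[0, :, 0].cpu().numpy()             # first (enhanced) source
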
# --- Model Loading Function ---
def load_hf_models():
    """Loads Hugging Face models at startup."""
    global enhancement_pipelines, separation_models
    if torch is None or pipeline is None:
        logger.warning("Torch or Transformers not available. Skipping Hugging Face model loading.")
        return

    # --- Load Enhancement Model ---
    # Using speechbrain/sepformer-whamr-enhancement via pipeline (check HF for the exact
    # pipeline task), or load the model/processor manually if no direct pipeline exists.
    enhancement_model_id = "speechbrain/sepformer-whamr-enhancement"  # Example ID
    try:
        logger.info(f"Loading enhancement model: {enhancement_model_id}...")
        # If no pipeline task fits, load manually:
        # enhancement_processor = AutoProcessor.from_pretrained(...)
        # enhancement_model = AutoModel...from_pretrained(...)
        # enhancement_pipelines['speechbrain_sepformer'] = {"processor": enhancement_processor, "model": enhancement_model}

        # For now, skip loading; this model requires a specific pipeline or manual setup.
        # enhancement_pipelines['speechbrain_sepformer'] = pipeline("audio-enhancement", model=enhancement_model_id)
        logger.warning(f"Skipping load for {enhancement_model_id} - requires a specific pipeline or manual setup.")

    except Exception as e:
        logger.error(f"Failed to load enhancement model '{enhancement_model_id}': {e}", exc_info=False)

    # --- Load Separation Model (Demucs) ---
    # Demucs is usually used directly, not via a standard HF pipeline task.
    separation_model_id = "facebook/demucs"  # Or a specific variant like facebook/hybrid_demucs
    try:
        logger.info(f"Loading separation model: {separation_model_id}...")
        # AutoModel might work for some variants if configured correctly on the HF hub:
        # separation_models['demucs'] = AutoModel.from_pretrained(separation_model_id)

        # More typically, you install the 'demucs' package itself (pip install -U demucs)
        # and load through it; see the hedged sketch below this function.

        # For now, simulate a loading failure, since direct AutoModel loading may not work.
        raise NotImplementedError("Demucs loading typically requires the 'demucs' package or specific manual loading.")

    except Exception as e:
        logger.error(f"Failed to load separation model '{separation_model_id}': {e}", exc_info=False)
        logger.warning("Note: Demucs loading often requires 'pip install demucs' and specific loading code, not just AutoModel.")

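# A hedged sketch of what real Demucs loading could look like, assuming the
# 'demucs' package (pip install demucs) and its get_model() API; verify the names
# against the installed demucs version before relying on this.
#
# def load_demucs_model(name: str = "htdemucs"):
#     from demucs.pretrained import get_model
#     model = get_model(name)  # downloads weights on first use
#     model.eval()
#     if torch.cuda.is_available():
#         model = model.cuda()
#     return model
#
# Note that models loaded via the demucs package are normally applied with
# demucs.apply.apply_model(model, mix_tensor) rather than called directly, so
# _run_separation_sync above would need a matching adjustment.
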
# --- FastAPI App and Endpoints ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and HF model dependencies.",
    version="2.0.0",
)

@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    logger.info("Application startup: Loading AI models...")
    # Run model loading in a separate thread so the event loop stays responsive,
    # although startup still waits for the thread to finish.
    # Consider truly background loading if startup time is critical.
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished (check logs for success/failure).")

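# @app.on_event("startup") still works but is deprecated in recent FastAPI versions
# in favor of a lifespan context manager. A hedged sketch of the equivalent, should
# the deprecation warning matter for your FastAPI version:
#
# from contextlib import asynccontextmanager
#
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     await asyncio.to_thread(load_hf_models)  # startup work
#     yield                                    # application runs here
#     # (optional shutdown/cleanup work goes after the yield)
#
# app = FastAPI(..., lifespan=lifespan)  # passed at construction instead of on_event
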
# --- Basic Editing Endpoints (Mostly Unchanged) ---

@app.get("/", tags=["General"])
def read_root():
    """Root endpoint providing a welcome message and the available features."""
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    if enhancement_pipelines: ai_features.append("/enhance")
    if separation_models: ai_features.append("/separate")

    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None available (models might have failed to load)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs)."
    }

# The /trim, /concat, /volume, and /convert endpoints remain largely the same as before;
# ensure they use the updated save_upload_file and cleanup logic.
# (Code for these endpoints omitted for brevity - refer to the previous example;
# a hedged sketch of /trim follows below.)
# ... Add /trim, /concat, /volume, /convert endpoints here ...

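# A minimal sketch of what the omitted /trim endpoint could look like, built from the
# helpers defined above and pydub's millisecond-based slicing. The parameter names and
# defaults here are illustrative assumptions, not the exact code from the earlier example.
@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(0, description="Trim start, in milliseconds."),
    end_ms: int = Form(..., description="Trim end, in milliseconds."),
    output_format: str = Form("wav", description="Output format for the trimmed audio.")
):
    """Trims an audio file to [start_ms, end_ms] (sketch; see the note above)."""
    input_path = await save_upload_file(file, prefix="trim_in_")
    background_tasks.add_task(cleanup_file, input_path)
    try:
        segment = AudioSegment.from_file(input_path)
        trimmed = segment[start_ms:end_ms]  # pydub slices by milliseconds
        output_path = os.path.join(TEMP_DIR, f"trimmed_{uuid.uuid4().hex}.{output_format}")
        trimmed.export(output_path, format=output_format)
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"trimmed_{file.filename}"
        )
    except CouldntDecodeError:
        raise HTTPException(status_code=415, detail="Could not decode the uploaded audio file.")
    except Exception as e:
        logger.error(f"Error during trim operation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trim: {str(e)}")
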
# --- AI Endpoints ---

@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    model_id: str = Query("speechbrain_sepformer", description="ID of the enhancement model to use (if multiple are loaded)."),
    output_format: str = Form("wav", description="Output format for the enhanced audio (wav or flac recommended).")
):
    """Enhances speech audio using a pre-loaded AI model (experimental)."""
    if torch is None or sf is None or np is None:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if model_id not in enhancement_pipelines:
        raise HTTPException(status_code=503, detail=f"Enhancement model '{model_id}' is not loaded or available.")

    logger.info(f"Enhance request: file='{file.filename}', model='{model_id}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)

    output_path = None  # Define output_path before the try block
    try:
        # Load audio, ensuring the correct sampling rate for the model
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)

        # Run inference in a separate thread
        logger.info("Submitting enhancement task to background thread...")
        model_pipeline = enhancement_pipelines[model_id]  # Get the specific loaded pipeline/model
        enhanced_audio = await asyncio.to_thread(
            _run_enhancement_sync, model_pipeline, audio_data, current_sr
        )
        logger.info("Enhancement task completed.")

        # Save the result (current_sr is the model's target rate at this point)
        output_path = save_hf_audio(enhanced_audio, current_sr, output_format)
        background_tasks.add_task(cleanup_file, output_path)

        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"enhanced_{file.filename}"
        )

    except Exception as e:
        logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)  # Clean up if the error occurred after the output was saved
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")

@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_id: str = Query("demucs", description="ID of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav or flac recommended).")
):
    """Separates music into stems (vocals, drums, bass, other) using Demucs (experimental). Returns a ZIP archive."""
    if torch is None or sf is None or np is None:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if model_id not in separation_models:
        raise HTTPException(status_code=503, detail=f"Separation model '{model_id}' is not loaded or available.")

    valid_stems = {'vocals', 'drums', 'bass', 'other'}
    requested_stems = set(s.lower() for s in stems)
    if not requested_stems.issubset(valid_stems):
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are: {', '.join(valid_stems)}")

    logger.info(f"Separate request: file='{file.filename}', model='{model_id}', stems={requested_stems}, format='{output_format}'")
    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)

    stem_output_paths: Dict[str, str] = {}
    zip_buffer = None
    try:
        # Load audio at the model's sampling rate (Demucs uses 44.1 kHz)
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)

        # Run inference in a separate thread
        logger.info("Submitting separation task to background thread...")
        model = separation_models[model_id]  # Get the specific loaded model
        all_separated_stems = await asyncio.to_thread(
            _run_separation_sync, model, audio_data, current_sr
        )
        logger.info("Separation task completed.")

        # --- Create the ZIP file in memory ---
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Save only the requested stems
            for stem_name in requested_stems:
                if stem_name in all_separated_stems:
                    stem_data = all_separated_stems[stem_name]
                    # Save the stem to disk first (needed for sf.write/pydub)
                    stem_path = save_hf_audio(stem_data, current_sr, output_format)
                    stem_output_paths[stem_name] = stem_path  # Keep track for cleanup
                    background_tasks.add_task(cleanup_file, stem_path)  # Schedule cleanup

                    # Add the saved stem file to the ZIP archive, named after the original upload
                    archive_name = f"{stem_name}_{os.path.splitext(file.filename)[0]}.{output_format}"
                    zipf.write(stem_path, arcname=archive_name)
                    logger.info(f"Added '{archive_name}' to ZIP.")
                else:
                    logger.warning(f"Requested stem '{stem_name}' not found in model output.")

        zip_buffer.seek(0)  # Rewind the buffer pointer

        # Return the ZIP file
        zip_filename = f"separated_stems_{os.path.splitext(file.filename)[0]}.zip"
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
        )

    except Exception as e:
        logger.error(f"Error during separation operation: {e}", exc_info=True)
        # Clean up any stems that were saved before zipping failed
        for path in stem_output_paths.values():
            cleanup_file(path)
        if zip_buffer: zip_buffer.close()  # Close the in-memory buffer

        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")

# --- How to Run ---
# 1. Ensure FFmpeg is installed and accessible.
# 2. Save this code as `app.py`.
# 3. Create `requirements.txt` (as shown above; a plausible reconstruction follows at the end of this file).
# 4. Install dependencies: `pip install -r requirements.txt` (this can take time!)
# 5. Run the FastAPI server: `uvicorn app:app --reload --host 0.0.0.0`
#    (Use --host 0.0.0.0 for external/Docker access; --reload is optional.)
#
# --- WARNING ---
# - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
# - The first run will download models, which can take a long time and a lot of disk space.
# - Ensure the specific model IDs used are correct and compatible with the HF libraries.
# - Model loading at startup may fail if dependencies are missing or resources are insufficient. Check the logs!
#
# --- Example Usage (using curl) ---
#
# **Enhance:** (enhance noisy_speech.wav)
# curl -X POST "http://127.0.0.1:8000/enhance?model_id=speechbrain_sepformer" \
#      -F "file=@noisy_speech.wav" \
#      -F "output_format=wav" \
#      --output enhanced_speech.wav
#
# **Separate:** (separate vocals and drums from music.mp3)
# curl -X POST "http://127.0.0.1:8000/separate?model_id=demucs" \
#      -F "file=@music.mp3" \
#      -F "stems=vocals" \
#      -F "stems=drums" \
#      -F "output_format=mp3" \
#      --output separated_stems.zip