Update app.py

app.py (CHANGED)
Removed lines (old version; lines cut off by the diff view are marked with …):

@@ -1,14 +1,15 @@
-import io # For zip file in memory
-import zipfile
@@ -19,20 +20,26 @@ try:
-    print(f"Error importing AI/Audio libraries: {e}")
@@ -43,15 +50,28 @@ logger = logging.getLogger(__name__)
-enhancement_pipelines: Dict[str, Any] = {}
-separation_models: Dict[str, Any] = {} #
-ENHANCEMENT_SR = 16000
-DEMUCS_SR = 44100
@@ -63,7 +83,6 @@ def cleanup_file(file_path: str):
-    # Generate a unique temporary file path
@@ -79,26 +98,32 @@ async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
-# --- Audio Loading/Saving for AI Models ---
-        # Resample if necessary
-            audio = librosa.resample(…)
@@ -112,22 +137,34 @@ def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
-        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
-            # Scale
@@ -142,118 +179,191 @@ def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
-# --- Synchronous AI Inference Functions (…
-        logger.info(f"Running speech enhancement (input shape: {audio_data.shape}, SR: {sampling_rate})...")
-        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
-        logger.info(f"Running source separation (input shape: {audio_data.shape}, SR: {sampling_rate})...")
-        # result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
-        # Manual example closer to raw Demucs model (if not using pipeline)
-        # Assuming `separation_models['demucs']` holds the loaded Demucs model instance
-        model = separation_models.get('demucs')
-        if not model: raise ValueError("Demucs model not loaded correctly")
-        # Demucs expects stereo input in shape (batch, channels, samples)
-        # Convert mono to stereo if needed, add batch dim
-        # Move to …
-        sources_np = sources.detach().cpu().numpy()
-        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
-    logger.warning("Torch or Transformers not available. Skipping Hugging Face model loading.")
-    # Or load model/processor manually if no direct pipeline exists
-    enhancement_model_id = "speechbrain/sepformer-whamr-enhancement" # Example ID
-        logger.error(f"Failed to load enhancement model '{…
-    separation_model_id = "facebook/demucs" # Or specific variant like facebook/hybrid_demucs
-        # For now, simulate loading failure as direct AutoModel might not work
-        raise NotImplementedError("Demucs loading typically requires the 'demucs' package or specific manual loading.")
-        logger.info(f"Separation model '{separation_model_id}' loaded.")
@@ -266,82 +376,267 @@ app = FastAPI(
-    logger.info("Application startup: Loading AI models...")
-    # Running model loading in a separate thread to avoid blocking startup completely
-    # although startup will wait for this thread to finish.
-    # Consider truly background loading if startup time is critical.
-    logger.info("Model loading process finished…
-# --- Basic Editing Endpoints (Mostly Unchanged) ---
-        "ai_features": ai_features if ai_features else "None…
-    logger.info(f"Enhance request: file='{file.filename}', …
-    output_path = None # Define output_path before try block
-        model_pipeline = enhancement_pipelines[model_id] # Get the specific loaded pipeline/model
-        output_path = save_hf_audio(enhanced_audio, …
-        if output_path: cleanup_file(output_path)
@@ -350,105 +645,96 @@ async def enhance_speech(
-    """Separates music into stems (vocals, drums, bass, other) using …
-    valid_stems = {'vocals', 'drums', 'bass', 'other'}
-    logger.info(f"Separate request: file='{file.filename}', model='{model_id}', stems={requested_stems}, format='{output_format}'")
-    input_path = await save_upload_file(file, prefix="separate_in_")
-    background_tasks.add_task(cleanup_file, input_path)
-    zip_buffer = …
-        model = separation_models[model_id] # Get the specific loaded model
-            archive_name = f"{stem_name}…
-                logger.warning(f"Requested stem '{stem_name}' not found in model output.")
-        # Return the ZIP file
-            zip_buffer,
-            headers={'Content-Disposition': f'attachment; filename="{…
-        for path in stem_output_paths.values():
-        if zip_buffer: zip_buffer.close()
-#
-# 1. Ensure FFmpeg is installed and accessible.
-# 2. Save this code as `app.py`.
-# 3. Create `requirements.txt` (as shown above).
-# 4. Install dependencies: `pip install -r requirements.txt` (This can take time!)
-# 5. Run the FastAPI server: `uvicorn app:app --reload --host 0.0.0.0`
-#    (Use --host 0.0.0.0 for external/Docker access. --reload is optional)
-#
-# --- WARNING ---
-# - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
-# - The first run will download models, which can take a long time and lots of disk space.
-# - Ensure the specific model IDs used are correct and compatible with HF libraries.
-# - Model loading at startup might fail if dependencies are missing or resources are insufficient. Check logs!
-#
-# --- Example Usage (using curl) ---
-#
-# **Enhance:** (Enhance noisy_speech.wav)
-# curl -X POST "http://127.0.0.1:8000/enhance?model_id=speechbrain_sepformer" \
-#      -F "file=@noisy_speech.wav" \
-#      -F "output_format=wav" \
-#      --output enhanced_speech.wav
-#
-# **Separate:** (Separate vocals and drums from music.mp3)
-# curl -X POST "http://127.0.0.1:8000/separate?model_id=demucs" \
-#      -F "file=@music.mp3" \
-#      -F "stems=vocals" \
-#      -F "stems=drums" \
-#      -F "output_format=mp3" \
-#      --output separated_stems.zip

Updated file (new version):
# ----------- START app.py -----------
import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any
import io
import zipfile

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse

# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

try:
    import torch
    from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor # Using pipeline for simplicity where possible
    # Specific model imports might be needed depending on the chosen approach
    # E.g. for Demucs V4 (Hybrid Transformer): from demucs.hdemucs import HDemucs
    # from demucs.pretrained import hdemucs_mmi
    import soundfile as sf
    import numpy as np
    import librosa # For resampling if needed
    AI_LIBRARIES_AVAILABLE = True
    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Warning: Error importing AI/Audio libraries: {e}")
    print("Ensure torch, transformers, soundfile, librosa are installed.")
    print("AI features will be unavailable.")
    AI_LIBRARIES_AVAILABLE = False
    # Define dummy placeholders if needed, or just rely on AI_LIBRARIES_AVAILABLE flag
    torch = None
    pipeline = None
    sf = None
    np = None
    librosa = None
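# A plausible requirements.txt for the imports above (a sketch inferred from the
# imports, not part of this commit; versions omitted; python-multipart is what
# FastAPI needs to accept Form/File fields):
#   fastapi
#   uvicorn[standard]
#   python-multipart
#   pydub
#   torch
#   transformers
#   soundfile
#   numpy
#   librosa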

# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Use dictionaries to potentially hold multiple models of each type later
enhancement_models: Dict[str, Any] = {} # Store model/processor or pipeline
separation_models: Dict[str, Any] = {} # Store model/processor or pipeline

# Target sampling rates for models (check model cards on Hugging Face!)
# These MUST match the models being loaded in download_models.py and load_hf_models
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
ENHANCEMENT_SR = 16000 # Sepformer uses 16kHz

# Note: facebook/demucs is deprecated in transformers >4.26. Use specific variants.
# Using facebook/htdemucs_ft for example (requires Demucs v4 style loading)
# Or find a model suitable for AutoModel if needed.
SEPARATION_MODEL_ID = "facebook/demucs_quantized" # Example using a quantized version (smaller, faster CPU)
# SEPARATION_MODEL_ID = "facebook/hdemucs_mmi" # Example for Multi-Media Instructions model (if using demucs lib)
DEMUCS_SR = 44100 # Demucs default is 44.1kHz

# Define HF_HOME cache directory *within* the container if downloading during build
HF_CACHE_DIR = os.environ.get("HF_HOME", "/app/hf_cache") # Use HF_HOME from Dockerfile or default

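# Hypothetical Dockerfile counterpart for the cache setup above (an assumption
# for illustration, not part of this commit):
#   ENV HF_HOME=/app/hf_cache
#   RUN python download_models.py  # pre-download models into the image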

# --- Helper Functions (cleanup_file, save_upload_file, load_audio_for_hf, save_hf_audio) ---
# (Include the helper functions from the previous app.py example here)
# ...
def cleanup_file(file_path: str):
    """Safely remove a file."""
    try:
        # ... (body elided in the diff view)

async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Saves an uploaded file to a temporary location and returns the path."""
    _, file_extension = os.path.splitext(upload_file.filename)
    if not file_extension: file_extension = ".wav" # Default if no extension
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    # ... (write/except logic elided in the diff view)
    finally:
        await upload_file.close()

def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[np.ndarray, int]:
    """Loads audio using soundfile, converts to mono float32, optionally resamples."""
    if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
        raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")

        if audio.ndim > 1 and audio.shape[-1] > 1: # Check last dimension for channels
            if audio.shape[0] == min(audio.shape): # If channels are first dim
                audio = audio.T # Transpose to (samples, channels)
            audio = np.mean(audio, axis=1)
            logger.info(f"Converted audio to mono, new shape: {audio.shape}")
        elif audio.ndim > 1: # If shape is like (1, N) or (N, 1)
            audio = audio.squeeze() # Remove singleton dimension
            logger.info(f"Squeezed audio to 1D, new shape: {audio.shape}")

        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("Librosa is required for resampling but not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            # Ensure audio is contiguous before resampling if necessary
            if not audio.flags['C_CONTIGUOUS']:
                audio = np.ascontiguousarray(audio)
            audio = librosa.resample(y=audio, orig_sr=orig_sr, target_sr=target_sr)
            logger.info(f"Resampled audio shape: {audio.shape}")
            current_sr = target_sr
        else:
            current_sr = orig_sr
        return audio, current_sr
    # ... (error handling elided in the diff view)
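# Worked example for load_audio_for_hf (assumed numbers, illustration only):
# a 10 s stereo 44.1 kHz file read as shape (441000, 2) is averaged to a mono
# array of shape (441000,); with target_sr=16000 it is resampled to about
# 441000 * 16000 / 44100 = 160000 samples, and (audio, 16000) is returned.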

def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves a NumPy audio array to a temporary file."""
    if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
        raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")

    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format}, shape={audio_data.shape})")

        # Ensure data is float32 for common formats like wav/flac
        if audio_data.dtype != np.float32:
            logger.warning(f"Audio data has dtype {audio_data.dtype}, converting to float32.")
            audio_data = audio_data.astype(np.float32)

        # Clip data to avoid issues with some formats/players if values go beyond [-1, 1]
        audio_data = np.clip(audio_data, -1.0, 1.0)

        # Use soundfile for lossless formats
        if output_format.lower() in ['wav', 'flac']:
            sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
        else:
            # For lossy formats like mp3, use pydub after converting numpy array
            logger.debug("Using pydub for lossy format export...")
            # Scale float32 [-1, 1] to int16 for pydub
            audio_int16 = (audio_data * 32767).astype(np.int16)
            if audio_int16.ndim > 1: # Should be mono by now, but double check
                logger.warning("Audio data still has multiple dimensions before pydub export, attempting mean.")
                audio_int16 = np.mean(audio_int16, axis=1).astype(np.int16)

            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=2, # bytes per sample for int16
                channels=1
            )
            segment.export(output_path, format=output_format)
        return output_path
    except Exception as e:
        logger.error(f"Error saving processed audio: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")
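# Quick check of the int16 scaling in save_hf_audio (illustrative): a float
# sample of 0.5 maps to int(0.5 * 32767) = 16383, and the preceding np.clip to
# [-1.0, 1.0] keeps every result inside the int16 range of +/-32767.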

# --- Synchronous AI Inference Functions (_run_enhancement_sync, _run_separation_sync) ---
# (Include the sync functions from the previous app.py example here)
# Make sure they handle potential model loading issues gracefully
# ...
def _run_enhancement_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
    """Synchronous wrapper for enhancement model inference."""
    if not AI_LIBRARIES_AVAILABLE or model_key not in enhancement_models:
        raise ValueError(f"Enhancement model '{model_key}' not available or AI libraries missing.")

    model_info = enhancement_models[model_key]
    # Adapt based on whether model_info holds a pipeline or model/processor
    # This example assumes a pipeline-like object is stored
    enhancer = model_info # Assuming pipeline
    if not enhancer: raise ValueError(f"Enhancement pipeline '{model_key}' is None.")

    try:
        logger.info(f"Running speech enhancement with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
        # Usage depends heavily on the specific model/pipeline interface
        # For SpeechBrain models often used *without* HF pipeline:
        # Example: enhanced_wav = enhancer.enhance_batch(torch.tensor(audio_data).unsqueeze(0), lengths=torch.tensor([audio_data.shape[0]]))
        # enhanced_audio = enhanced_wav.squeeze(0).cpu().numpy()

        # If using a generic HF pipeline:
        result = enhancer({"raw": audio_data, "sampling_rate": sampling_rate})
        enhanced_audio = result["audio"]["array"] # Adjust based on actual pipeline output

        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference with '{model_key}': {e}", exc_info=True)
        raise # Re-raise to be caught by the async wrapper


def _run_separation_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
    """Synchronous wrapper for source separation model inference."""
    if not AI_LIBRARIES_AVAILABLE or model_key not in separation_models:
        raise ValueError(f"Separation model '{model_key}' not available or AI libraries missing.")

    model_info = separation_models[model_key]
    model = model_info # Assuming direct model object is stored for Demucs

    if not model: raise ValueError(f"Separation model '{model_key}' is None.")

    try:
        logger.info(f"Running source separation with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")

        # Prepare input tensor for Demucs-like models
        # Expects (batch, channels, samples), float32
        if audio_data.ndim == 1:
            # Need stereo for standard Demucs
            logger.debug("Separation input is mono, duplicating to create stereo.")
            audio_data = np.stack([audio_data, audio_data], axis=0) # (2, samples)
        if audio_data.shape[0] != 2:
            # If it's somehow (samples, 2), transpose
            if audio_data.shape[1] == 2: audio_data = audio_data.T
            else: raise ValueError(f"Unexpected input audio shape for separation: {audio_data.shape}")

        audio_tensor = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0) # (1, 2, samples)

        # Move to model's device (CPU or GPU)
        device = next(model.parameters()).device
        logger.debug(f"Moving separation tensor to device: {device}")
        audio_tensor = audio_tensor.to(device)

        # Perform inference
        with torch.no_grad():
            logger.debug("Starting model inference for separation...")
            # Output shape depends on model, e.g., (batch, stems, channels, samples)
            sources = model(audio_tensor)[0] # Remove batch dim
            logger.debug(f"Model inference complete, sources shape: {sources.shape}")

        # Detach, move to CPU, convert to numpy
        sources_np = sources.detach().cpu().numpy() # (stems, channels, samples)

        # Define stem order based on the *specific* Demucs model used
        # This order is for default Demucs v3/v4 (facebook/demucs, facebook/htdemucs_ft, etc.)
        stem_names = ['drums', 'bass', 'other', 'vocals']
        if sources_np.shape[0] != len(stem_names):
            logger.warning(f"Model output {sources_np.shape[0]} stems, expected {len(stem_names)}. Stem names might be incorrect.")
            # Fallback names if shape mismatch
            stem_names = [f"stem_{i+1}" for i in range(sources_np.shape[0])]

        stems = {}
        for i, name in enumerate(stem_names):
            # Average channels to get mono stem
            mono_stem = np.mean(sources_np[i], axis=0)
            stems[name] = mono_stem
            logger.debug(f"Extracted stem '{name}', shape: {mono_stem.shape}")

        logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
        return stems

    except Exception as e:
        logger.error(f"Error during synchronous separation inference with '{model_key}': {e}", exc_info=True)
        raise

# --- Model Loading Function ---
# (Include the load_hf_models function from the previous app.py example here)
# Make sure it uses the correct model IDs and potentially adjusts loading logic
# if using libraries like `demucs` directly.
# ...
def load_hf_models():
    """Loads Hugging Face models at startup."""
    if not AI_LIBRARIES_AVAILABLE:
        logger.warning("Skipping Hugging Face model loading as libraries are missing.")
        return

    global enhancement_models, separation_models

    # --- Load Enhancement Model ---
    enhancement_key = "speechbrain_enhancer" # Internal key
    try:
        logger.info(f"Attempting to load enhancement model: {ENHANCEMENT_MODEL_ID}...")
        # SpeechBrain models often require specific loading from their toolkit or HF spaces
        # This might involve cloning a repo or using specific classes.
        # Using HF pipeline if available, otherwise manual load.
        # Example using pipeline (might not work for all speechbrain models):
        # enhancement_models[enhancement_key] = pipeline(
        #     "audio-enhancement", # Or appropriate task
        #     model=ENHANCEMENT_MODEL_ID,
        #     cache_dir=HF_CACHE_DIR,
        #     device=0 if torch.cuda.is_available() else -1 # Use GPU if possible
        # )
        # Manual load might be needed:
        # from speechbrain.pretrained import SepformerEnhancement
        # enhancer = SepformerEnhancement.from_hparams(
        #     source=ENHANCEMENT_MODEL_ID,
        #     savedir=os.path.join(HF_CACHE_DIR, "speechbrain", ENHANCEMENT_MODEL_ID.split('/')[-1]),
        #     run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"}
        # )
        # enhancement_models[enhancement_key] = enhancer
        logger.warning(f"Actual loading for {ENHANCEMENT_MODEL_ID} skipped - requires SpeechBrain toolkit or specific pipeline setup.")
        # To make the endpoint testable without full setup:
        # enhancement_models[enhancement_key] = None # Or a dummy function

    except Exception as e:
        logger.error(f"Failed to load enhancement model '{ENHANCEMENT_MODEL_ID}': {e}", exc_info=False)


    # --- Load Separation Model (Demucs) ---
    separation_key = "demucs_separator" # Internal key
    try:
        logger.info(f"Attempting to load separation model: {SEPARATION_MODEL_ID}...")
        # Loading Demucs models can be complex.
        # Option 1: Use AutoModel if the HF Hub version supports it (less common for Demucs)
        # Option 2: Use the `demucs` library (recommended if installed: pip install -U demucs)
        # Option 3: Find a Transformers-compatible version if available.

        # Example using AutoModel (Try this first, might work for some quantized/HF versions)
        try:
            # Determine device
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Loading Demucs on device: {device}")
            # Check if AutoModelForSpeechSeq2Seq is appropriate, might need a different AutoModel class
            separation_models[separation_key] = AutoModelForSpeechSeq2Seq.from_pretrained(
                SEPARATION_MODEL_ID,
                cache_dir=HF_CACHE_DIR
                # Add trust_remote_code=True if needed for custom model code on HF hub
            ).to(device)

            # Check if the loaded model has an 'eval' method (common for PyTorch models)
            if hasattr(separation_models[separation_key], 'eval'):
                separation_models[separation_key].eval() # Set to evaluation mode

            logger.info(f"Successfully loaded separation model '{SEPARATION_MODEL_ID}' using AutoModel.")

        except Exception as auto_model_err:
            logger.warning(f"Failed to load '{SEPARATION_MODEL_ID}' using AutoModel: {auto_model_err}. Consider installing 'demucs' library.")
            separation_models[separation_key] = None # Ensure it's None if loading failed

        # Example using `demucs` library (if installed)
        # try:
        #     import demucs.separate
        #     model = demucs.apply.load_model(pretrained_model_path_or_url) # Needs adjustment
        #     separation_models[separation_key] = model
        #     logger.info(f"Successfully loaded separation model using 'demucs' library.")
        # except ImportError:
        #     logger.error("Cannot load Demucs: 'demucs' library not found. Please run 'pip install -U demucs'.")
        # except Exception as demucs_lib_err:
        #     logger.error(f"Error loading model using 'demucs' library: {demucs_lib_err}")

    except Exception as e:
        logger.error(f"General error loading separation model '{SEPARATION_MODEL_ID}': {e}", exc_info=False)
        if separation_key in separation_models: separation_models[separation_key] = None


# --- FastAPI App and Endpoints ---
app = FastAPI(
    # ... (title/description elided in the diff view)
)

@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    logger.info("Application startup: Loading AI models (this may take time)...")
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished.")


# --- API Endpoints ---
# (Include / , /trim, /concat, /volume, /convert endpoints here - same as previous version)
# ...
@app.get("/", tags=["General"])
def read_root():
    """Root endpoint providing a welcome message and available features."""
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    # Check if models were successfully loaded (i.e., not None)
    if any(model is not None for model in enhancement_models.values()): ai_features.append("/enhance")
    if any(model is not None for model in separation_models.values()): ai_features.append("/separate")

    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None loaded (check logs)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs)."
    }

@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., description="Start time in milliseconds."),
    end_ms: int = Form(..., description="End time in milliseconds.")
):
    """Trims an audio file to the specified start and end times (in milliseconds)."""
    if start_ms < 0 or end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")

    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="trim_in_")
        background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup

        # Use Pydub for basic trim
        audio = AudioSegment.from_file(input_path)
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")

        original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
        if not original_format or original_format == "tmp": original_format = "mp3"
        output_filename = f"trimmed_{uuid.uuid4().hex}.{original_format}"
        output_path = os.path.join(TEMP_DIR, output_filename)

        trimmed_audio.export(output_path, format=original_format)
        background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup

        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}", # Attempt correct media type
            filename=f"trimmed_{file.filename}"
        )
    except CouldntDecodeError:
        logger.warning(f"pydub failed to decode: {file.filename}")
        raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
    except Exception as e:
        logger.error(f"Error during trim operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path) # Immediate cleanup on error
        if input_path: cleanup_file(input_path) # Immediate cleanup on error
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}")


@app.post("/concat", tags=["Basic Editing"])
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
):
    """Concatenates two or more audio files sequentially."""
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")

    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
    input_paths = []
    loaded_audios = []
    output_path = None

    try:
        combined_audio = AudioSegment.empty()
        first_filename_base = "combined"

        for i, file in enumerate(files):
            input_path = await save_upload_file(file, prefix=f"concat_{i}_")
            input_paths.append(input_path)
            background_tasks.add_task(cleanup_file, input_path)
            audio = AudioSegment.from_file(input_path)
            combined_audio += audio
            if i == 0: first_filename_base = os.path.splitext(file.filename)[0]
            logger.info(f"Appended '{file.filename}', current total duration: {len(combined_audio)}ms")

        if len(combined_audio) == 0:
            raise HTTPException(status_code=500, detail="No audio data after attempting concatenation.")

        output_filename_final = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
        output_path = os.path.join(TEMP_DIR, f"concat_out_{uuid.uuid4().hex}.{output_format}")

        combined_audio.export(output_path, format=output_format)
        background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup

        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename_final
        )
    except CouldntDecodeError as e:
        logger.warning(f"pydub failed to decode one of the concat files: {e}")
        raise HTTPException(status_code=415, detail=f"Unsupported format or corrupted file among inputs: {e}")
    except Exception as e:
        logger.error(f"Error during concat operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        for p in input_paths: cleanup_file(p)
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}")

@app.post("/volume", tags=["Basic Editing"])
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). Positive values increase volume, negative values decrease.")
):
    """Adjusts the volume of an audio file by a specified decibel amount."""
    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="volume_in_")
        background_tasks.add_task(cleanup_file, input_path)

        audio = AudioSegment.from_file(input_path)
        adjusted_audio = audio + change_db
        logger.info(f"Volume adjusted by {change_db}dB.")

        original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
        if not original_format or original_format == "tmp": original_format = "mp3"
        output_filename_final = f"volume_{change_db}dB_{file.filename}"
        output_path = os.path.join(TEMP_DIR, f"volume_out_{uuid.uuid4().hex}.{original_format}")

        adjusted_audio.export(output_path, format=original_format)
        background_tasks.add_task(cleanup_file, output_path)

        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",
            filename=output_filename_final
        )
    except CouldntDecodeError:
        logger.warning(f"pydub failed to decode: {file.filename}")
        raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
    except Exception as e:
        logger.error(f"Error during volume operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}")

@app.post("/convert", tags=["Basic Editing"])
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
):
    """Converts an audio file to a different format."""
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a'}
    if output_format.lower() not in allowed_formats:
        raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed: {', '.join(allowed_formats)}")

    logger.info(f"Convert request: file='{file.filename}', output_format='{output_format}'")
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="convert_in_")
        background_tasks.add_task(cleanup_file, input_path)

        audio = AudioSegment.from_file(input_path)

        output_format_lower = output_format.lower()
        filename_base = os.path.splitext(file.filename)[0]
        output_filename_final = f"{filename_base}_converted.{output_format_lower}"
        output_path = os.path.join(TEMP_DIR, f"convert_out_{uuid.uuid4().hex}.{output_format_lower}")

        audio.export(output_path, format=output_format_lower)
        background_tasks.add_task(cleanup_file, output_path)

        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format_lower}",
            filename=output_filename_final
        )
    except CouldntDecodeError:
        logger.warning(f"pydub failed to decode: {file.filename}")
        raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
    except Exception as e:
        logger.error(f"Error during convert operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}")


# (Include /enhance and /separate AI endpoints here - same as previous version)
# ...
@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    model_key: str = Query("speechbrain_enhancer", description="Internal key of the enhancement model to use."),
    output_format: str = Form("wav", description="Output format for the enhanced audio (wav, flac recommended).")
):
    """Enhances speech audio using a pre-loaded AI model (experimental)."""
    if not AI_LIBRARIES_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if model_key not in enhancement_models or enhancement_models[model_key] is None:
        logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")

    logger.info(f"Enhance request: file='{file.filename}', model_key='{model_key}', format='{output_format}'")
    input_path = None
    output_path = None
    try:
        input_path = await save_upload_file(file, prefix="enhance_in_")
        background_tasks.add_task(cleanup_file, input_path)

        # Load audio, ensure correct SR for the model
        logger.debug(f"Loading audio for enhancement, target SR: {ENHANCEMENT_SR}")
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
        if current_sr != ENHANCEMENT_SR: # Should have been resampled, but double check
            logger.warning(f"Audio SR after loading is {current_sr}, expected {ENHANCEMENT_SR}. Check resampling.")
            # Depending on model strictness, could raise error or proceed cautiously.
            # raise HTTPException(status_code=500, detail="Audio resampling failed.")

        # Run inference in a separate thread
        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio = await asyncio.to_thread(
            _run_enhancement_sync, model_key, audio_data, current_sr # Pass key, data, and ACTUAL sr used
        )
        logger.info("Enhancement task completed.")

        # Save the result
        output_path = save_hf_audio(enhanced_audio, ENHANCEMENT_SR, output_format) # Save with model's target SR
        background_tasks.add_task(cleanup_file, output_path)

        output_filename_final = f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename_final
        )

    except Exception as e:
        logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        if input_path: cleanup_file(input_path)
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")


@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: str = Query("demucs_separator", description="Internal key of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separates music into stems (vocals, drums, bass, other) using a pre-loaded AI model (experimental). Returns a ZIP archive."""
    if not AI_LIBRARIES_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    if model_key not in separation_models or separation_models[model_key] is None:
        logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")

    valid_stems = {'vocals', 'drums', 'bass', 'other'} # Based on typical Demucs output
    requested_stems = set(s.lower() for s in stems)
    if not requested_stems.issubset(valid_stems):
        # Allow if all stems are requested even if validation set is smaller? Or just error.
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are generally: {', '.join(valid_stems)}")

    logger.info(f"Separate request: file='{file.filename}', model_key='{model_key}', stems={requested_stems}, format='{output_format}'")
    input_path = None
    stem_output_paths: Dict[str, str] = {}
    zip_buffer = io.BytesIO() # Use BytesIO for in-memory ZIP

    try:
        input_path = await save_upload_file(file, prefix="separate_in_")
        background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup

        # Load audio, ensure correct SR for the model
        logger.debug(f"Loading audio for separation, target SR: {DEMUCS_SR}")
        audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
        if current_sr != DEMUCS_SR:
            logger.warning(f"Audio SR after loading is {current_sr}, expected {DEMUCS_SR}. Check resampling.")
            # raise HTTPException(status_code=500, detail="Audio resampling failed.")

        # Run inference in a separate thread
        logger.info("Submitting separation task to background thread...")
        all_separated_stems = await asyncio.to_thread(
            _run_separation_sync, model_key, audio_data, current_sr # Pass key, data, actual SR
        )
        logger.info("Separation task completed.")

        # --- Create ZIP file in memory ---
        zip_filename_base = f"separated_{os.path.splitext(file.filename)[0]}"
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            logger.info(f"Creating ZIP archive in memory...")
            found_stems_count = 0
            for stem_name in requested_stems:
                if stem_name in all_separated_stems:
                    stem_data = all_separated_stems[stem_name]
                    if stem_data is None or stem_data.size == 0:
                        logger.warning(f"Stem '{stem_name}' data is empty, skipping.")
                        continue

                    # Save stem temporarily to disk first (needed for pydub/sf.write)
                    logger.debug(f"Saving temporary stem file for '{stem_name}'...")
                    stem_path = save_hf_audio(stem_data, DEMUCS_SR, output_format) # Save with model's target SR
                    stem_output_paths[stem_name] = stem_path # Keep track for cleanup
                    background_tasks.add_task(cleanup_file, stem_path) # Schedule cleanup

                    # Add the saved stem file to the ZIP archive
                    archive_name = f"{stem_name}.{output_format}" # Simple name inside zip
                    zipf.write(stem_path, arcname=archive_name)
                    logger.info(f"Added '{archive_name}' to ZIP.")
                    found_stems_count += 1
                else:
                    logger.warning(f"Requested stem '{stem_name}' not found in model output keys: {list(all_separated_stems.keys())}")

            if found_stems_count == 0:
                raise HTTPException(status_code=404, detail="None of the requested stems were found or generated successfully.")

        zip_buffer.seek(0) # Rewind buffer pointer

        # Return the ZIP file via StreamingResponse
        zip_filename_download = f"{zip_filename_base}.zip"
        logger.info(f"Sending ZIP file '{zip_filename_download}'")
        return StreamingResponse(
            zip_buffer, # Pass the BytesIO buffer directly
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename_download}"'}
        )

    except Exception as e:
        logger.error(f"Error during separation operation: {e}", exc_info=True)
        # Cleanup temporary stem files if they exist
        for path in stem_output_paths.values(): cleanup_file(path)
        # Close buffer just in case (though StreamingResponse should handle it)
        # if zip_buffer and not zip_buffer.closed: zip_buffer.close()
        if input_path: cleanup_file(input_path)

        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")


# ----------- END app.py -----------
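# This commit drops the old trailing usage notes. For reference, a hedged
# sketch of equivalent curl calls against the updated endpoints (paths and
# field names come from the code above; hosts and filenames are placeholders):
#
# Basic editing:
# curl -X POST "http://127.0.0.1:8000/trim" -F "file=@song.mp3" \
#      -F "start_ms=1000" -F "end_ms=5000" --output trimmed.mp3
# curl -X POST "http://127.0.0.1:8000/concat" -F "files=@a.wav" -F "files=@b.wav" \
#      -F "output_format=mp3" --output joined.mp3
# curl -X POST "http://127.0.0.1:8000/volume" -F "file=@quiet.wav" \
#      -F "change_db=6" --output louder.wav
# curl -X POST "http://127.0.0.1:8000/convert" -F "file=@input.wav" \
#      -F "output_format=flac" --output converted.flac
#
# AI endpoints (model_key is a query parameter; /separate returns a ZIP):
# curl -X POST "http://127.0.0.1:8000/enhance?model_key=speechbrain_enhancer" \
#      -F "file=@noisy_speech.wav" -F "output_format=wav" --output enhanced.wav
# curl -X POST "http://127.0.0.1:8000/separate?model_key=demucs_separator" \
#      -F "file=@music.mp3" -F "stems=vocals" -F "stems=drums" \
#      -F "output_format=wav" --output separated_stems.zip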