Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
-
# ----------- START app.py -----------
|
2 |
import os
|
3 |
import uuid
|
4 |
import tempfile
|
5 |
import logging
|
6 |
import asyncio
|
7 |
from typing import List, Optional, Dict, Any
|
8 |
-
import io
|
9 |
-
import zipfile
|
10 |
|
11 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
|
12 |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
|
|
|
|
13 |
|
14 |
# --- Basic Editing Imports ---
|
15 |
from pydub import AudioSegment
|
@@ -18,62 +17,58 @@ from pydub.exceptions import CouldntDecodeError
|
|
18 |
# --- AI & Advanced Audio Imports ---
|
19 |
try:
|
20 |
import torch
|
21 |
-
|
22 |
-
#
|
23 |
-
# E.g. for Demucs V4 (Hybrid Transformer): from demucs.hdemucs import HDemucs
|
24 |
-
# from demucs.pretrained import hdemucs_mmi
|
25 |
import soundfile as sf
|
26 |
import numpy as np
|
27 |
-
import librosa
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
29 |
print("AI and advanced audio libraries loaded.")
|
30 |
except ImportError as e:
|
31 |
-
print(f"
|
32 |
-
print("Ensure torch,
|
33 |
print("AI features will be unavailable.")
|
34 |
-
AI_LIBRARIES_AVAILABLE = False
|
35 |
-
# Define dummy placeholders if needed, or just rely on AI_LIBRARIES_AVAILABLE flag
|
36 |
torch = None
|
37 |
-
pipeline = None
|
38 |
sf = None
|
39 |
np = None
|
40 |
librosa = None
|
41 |
-
|
|
|
42 |
|
43 |
# --- Configuration & Setup ---
|
44 |
TEMP_DIR = tempfile.gettempdir()
|
45 |
os.makedirs(TEMP_DIR, exist_ok=True)
|
46 |
|
47 |
-
# Configure logging
|
48 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
49 |
logger = logging.getLogger(__name__)
|
50 |
|
51 |
# --- Global Variables for Loaded Models ---
|
52 |
-
# Use
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
# Target sampling rates for models (check model cards on Hugging Face!)
|
57 |
-
# These MUST match the models being loaded in download_models.py and load_hf_models
|
58 |
-
ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
|
59 |
-
ENHANCEMENT_SR = 16000 # Sepformer uses 16kHz
|
60 |
-
|
61 |
-
# Note: facebook/demucs is deprecated in transformers >4.26. Use specific variants.
|
62 |
-
# Using facebook/htdemucs_ft for example (requires Demucs v4 style loading)
|
63 |
-
# Or find a model suitable for AutoModel if needed.
|
64 |
-
SEPARATION_MODEL_ID = "facebook/demucs_quantized" # Example using a quantized version (smaller, faster CPU)
|
65 |
-
# SEPARATION_MODEL_ID = "facebook/hdemucs_mmi" # Example for Multi-Media Instructions model (if using demucs lib)
|
66 |
-
DEMUCS_SR = 44100 # Demucs default is 44.1kHz
|
67 |
|
68 |
-
|
69 |
-
|
70 |
|
|
|
|
|
|
|
71 |
|
72 |
-
# ---
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
75 |
def cleanup_file(file_path: str):
|
76 |
-
"""Safely remove a file."""
|
77 |
try:
|
78 |
if file_path and os.path.exists(file_path):
|
79 |
os.remove(file_path)
|
@@ -82,9 +77,8 @@ def cleanup_file(file_path: str):
|
|
82 |
logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
|
83 |
|
84 |
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
|
85 |
-
"""Saves an uploaded file to a temporary location and returns the path."""
|
86 |
_, file_extension = os.path.splitext(upload_file.filename)
|
87 |
-
if not file_extension: file_extension = ".wav"
|
88 |
temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
|
89 |
try:
|
90 |
with open(temp_file_path, "wb") as buffer:
|
@@ -98,78 +92,83 @@ async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") ->
|
|
98 |
finally:
|
99 |
await upload_file.close()
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
|
105 |
try:
|
106 |
audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
|
107 |
logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
|
108 |
|
109 |
-
if audio.ndim > 1 and audio.shape[
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
elif audio.ndim > 1: # If shape is like (1, N) or (N, 1)
|
115 |
-
audio = audio.squeeze() # Remove singleton dimension
|
116 |
-
logger.info(f"Squeezed audio to 1D, new shape: {audio.shape}")
|
117 |
|
|
|
|
|
118 |
|
|
|
119 |
if target_sr and orig_sr != target_sr:
|
120 |
-
if librosa is None:
|
121 |
-
raise RuntimeError("Librosa is required for resampling but not installed.")
|
122 |
logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
|
123 |
-
#
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
logger.info(f"Resampled audio shape: {audio.shape}")
|
128 |
current_sr = target_sr
|
|
|
129 |
else:
|
130 |
current_sr = orig_sr
|
131 |
|
132 |
-
|
|
|
133 |
|
134 |
except Exception as e:
|
135 |
logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
|
136 |
raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
|
137 |
|
138 |
-
def save_hf_audio(audio_data:
|
139 |
-
"""Saves
|
140 |
-
if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
|
141 |
-
raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
|
142 |
-
|
143 |
output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
|
144 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
145 |
try:
|
146 |
-
logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
-
# Ensure data is float32
|
149 |
-
if
|
150 |
-
logger.warning(f"
|
151 |
-
|
152 |
|
153 |
-
# Clip
|
154 |
-
|
155 |
|
156 |
-
# Use soundfile for
|
157 |
if output_format.lower() in ['wav', 'flac']:
|
158 |
-
sf.write(output_path,
|
159 |
else:
|
160 |
-
# For lossy formats
|
161 |
-
logger.debug("Using pydub
|
162 |
# Scale float32 [-1, 1] to int16 for pydub
|
163 |
-
audio_int16 = (
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
segment = AudioSegment(
|
169 |
audio_int16.tobytes(),
|
170 |
frame_rate=sampling_rate,
|
171 |
sample_width=audio_int16.dtype.itemsize,
|
172 |
-
channels=1 #
|
173 |
)
|
174 |
segment.export(output_path, format=output_format)
|
175 |
|
@@ -179,464 +178,210 @@ def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str
|
|
179 |
cleanup_file(output_path)
|
180 |
raise HTTPException(status_code=500, detail="Failed to save processed audio.")
|
181 |
|
182 |
-
# --- Synchronous AI Inference Functions
|
183 |
-
# (Include the sync functions from the previous app.py example here)
|
184 |
-
# Make sure they handle potential model loading issues gracefully
|
185 |
-
# ...
|
186 |
-
def _run_enhancement_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
|
187 |
-
"""Synchronous wrapper for enhancement model inference."""
|
188 |
-
if not AI_LIBRARIES_AVAILABLE or model_key not in enhancement_models:
|
189 |
-
raise ValueError(f"Enhancement model '{model_key}' not available or AI libraries missing.")
|
190 |
-
|
191 |
-
model_info = enhancement_models[model_key]
|
192 |
-
# Adapt based on whether model_info holds a pipeline or model/processor
|
193 |
-
# This example assumes a pipeline-like object is stored
|
194 |
-
enhancer = model_info # Assuming pipeline
|
195 |
-
if not enhancer: raise ValueError(f"Enhancement pipeline '{model_key}' is None.")
|
196 |
|
|
|
|
|
|
|
197 |
try:
|
198 |
-
logger.info(f"Running speech enhancement
|
199 |
-
#
|
200 |
-
#
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
|
|
|
|
|
208 |
logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
|
209 |
return enhanced_audio
|
210 |
except Exception as e:
|
211 |
-
logger.error(f"Error during synchronous enhancement inference
|
212 |
-
raise
|
213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
222 |
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
-
try:
|
226 |
-
logger.info(f"Running source separation with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
|
227 |
-
|
228 |
-
# Prepare input tensor for Demucs-like models
|
229 |
-
# Expects (batch, channels, samples), float32
|
230 |
-
if audio_data.ndim == 1:
|
231 |
-
# Need stereo for standard Demucs
|
232 |
-
logger.debug("Separation input is mono, duplicating to create stereo.")
|
233 |
-
audio_data = np.stack([audio_data, audio_data], axis=0) # (2, samples)
|
234 |
-
if audio_data.shape[0] != 2:
|
235 |
-
# If it's somehow (samples, 2), transpose
|
236 |
-
if audio_data.shape[1] == 2: audio_data = audio_data.T
|
237 |
-
else: raise ValueError(f"Unexpected input audio shape for separation: {audio_data.shape}")
|
238 |
-
|
239 |
-
audio_tensor = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0) # (1, 2, samples)
|
240 |
-
|
241 |
-
# Move to model's device (CPU or GPU)
|
242 |
-
device = next(model.parameters()).device
|
243 |
-
logger.debug(f"Moving separation tensor to device: {device}")
|
244 |
-
audio_tensor = audio_tensor.to(device)
|
245 |
-
|
246 |
-
# Perform inference
|
247 |
with torch.no_grad():
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
#
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
stems[name] = mono_stem
|
269 |
-
logger.debug(f"Extracted stem '{name}', shape: {mono_stem.shape}")
|
270 |
-
|
271 |
-
logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
|
272 |
-
return stems
|
273 |
|
274 |
except Exception as e:
|
275 |
-
logger.error(f"Error during synchronous separation inference
|
276 |
raise
|
277 |
|
278 |
# --- Model Loading Function ---
|
279 |
-
# (Include the load_hf_models function from the previous app.py example here)
|
280 |
-
# Make sure it uses the correct model IDs and potentially adjusts loading logic
|
281 |
-
# if using libraries like `demucs` directly.
|
282 |
-
# ...
|
283 |
def load_hf_models():
|
284 |
-
"""Loads
|
285 |
-
if not AI_LIBRARIES_AVAILABLE:
|
286 |
-
logger.warning("Skipping Hugging Face model loading as libraries are missing.")
|
287 |
-
return
|
288 |
-
|
289 |
global enhancement_models, separation_models
|
|
|
|
|
|
|
290 |
|
291 |
-
# --- Load Enhancement Model ---
|
292 |
-
|
293 |
try:
|
294 |
-
logger.info(f"
|
295 |
-
# SpeechBrain
|
296 |
-
#
|
297 |
-
#
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
# Manual load might be needed:
|
306 |
-
# from speechbrain.pretrained import SepformerEnhancement
|
307 |
-
# enhancer = SepformerEnhancement.from_hparams(
|
308 |
-
# source=ENHANCEMENT_MODEL_ID,
|
309 |
-
# savedir=os.path.join(HF_CACHE_DIR, "speechbrain", ENHANCEMENT_MODEL_ID.split('/')[-1]),
|
310 |
-
# run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"}
|
311 |
-
# )
|
312 |
-
# enhancement_models[enhancement_key] = enhancer
|
313 |
-
logger.warning(f"Actual loading for {ENHANCEMENT_MODEL_ID} skipped - requires SpeechBrain toolkit or specific pipeline setup.")
|
314 |
-
# To make the endpoint testable without full setup:
|
315 |
-
# enhancement_models[enhancement_key] = None # Or a dummy function
|
316 |
-
|
317 |
except Exception as e:
|
318 |
-
logger.error(f"Failed to load enhancement model '{
|
319 |
-
|
320 |
|
321 |
# --- Load Separation Model (Demucs) ---
|
322 |
-
|
|
|
323 |
try:
|
324 |
-
logger.info(f"
|
325 |
-
#
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
# Example using AutoModel (Try this first, might work for some quantized/HF versions)
|
331 |
-
try:
|
332 |
-
# Determine device
|
333 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
334 |
-
logger.info(f"Loading Demucs on device: {device}")
|
335 |
-
# Check if AutoModelForSpeechSeq2Seq is appropriate, might need a different AutoModel class
|
336 |
-
separation_models[separation_key] = AutoModelForSpeechSeq2Seq.from_pretrained(
|
337 |
-
SEPARATION_MODEL_ID,
|
338 |
-
cache_dir=HF_CACHE_DIR
|
339 |
-
# Add trust_remote_code=True if needed for custom model code on HF hub
|
340 |
-
).to(device)
|
341 |
-
|
342 |
-
# Check if the loaded model has an 'eval' method (common for PyTorch models)
|
343 |
-
if hasattr(separation_models[separation_key], 'eval'):
|
344 |
-
separation_models[separation_key].eval() # Set to evaluation mode
|
345 |
-
|
346 |
-
logger.info(f"Successfully loaded separation model '{SEPARATION_MODEL_ID}' using AutoModel.")
|
347 |
-
|
348 |
-
except Exception as auto_model_err:
|
349 |
-
logger.warning(f"Failed to load '{SEPARATION_MODEL_ID}' using AutoModel: {auto_model_err}. Consider installing 'demucs' library.")
|
350 |
-
separation_models[separation_key] = None # Ensure it's None if loading failed
|
351 |
-
|
352 |
-
# Example using `demucs` library (if installed)
|
353 |
-
# try:
|
354 |
-
# import demucs.separate
|
355 |
-
# model = demucs.apply.load_model(pretrained_model_path_or_url) # Needs adjustment
|
356 |
-
# separation_models[separation_key] = model
|
357 |
-
# logger.info(f"Successfully loaded separation model using 'demucs' library.")
|
358 |
-
# except ImportError:
|
359 |
-
# logger.error("Cannot load Demucs: 'demucs' library not found. Please run 'pip install -U demucs'.")
|
360 |
-
# except Exception as demucs_lib_err:
|
361 |
-
# logger.error(f"Error loading model using 'demucs' library: {demucs_lib_err}")
|
362 |
-
|
363 |
-
|
364 |
except Exception as e:
|
365 |
-
logger.error(f"
|
366 |
-
|
367 |
|
368 |
|
369 |
-
# --- FastAPI App
|
370 |
app = FastAPI(
|
371 |
title="AI Audio Editor API",
|
372 |
-
description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and
|
373 |
-
version="2.
|
374 |
)
|
375 |
|
376 |
@app.on_event("startup")
|
377 |
async def startup_event():
|
378 |
-
"
|
379 |
-
|
380 |
await asyncio.to_thread(load_hf_models)
|
381 |
-
logger.info("Model loading process finished.")
|
382 |
-
|
383 |
|
384 |
# --- API Endpoints ---
|
385 |
-
|
386 |
-
# ...
|
387 |
@app.get("/", tags=["General"])
|
388 |
def read_root():
|
389 |
-
|
390 |
features = ["/trim", "/concat", "/volume", "/convert"]
|
391 |
ai_features = []
|
392 |
-
|
393 |
-
if
|
394 |
-
if any(model is not None for model in separation_models.values()): ai_features.append("/separate")
|
395 |
|
396 |
return {
|
397 |
"message": "Welcome to the AI Audio Editor API.",
|
398 |
"basic_features": features,
|
399 |
-
"ai_features": ai_features if ai_features else "None
|
400 |
-
"notes": "Requires FFmpeg. AI features require specific models loaded at startup
|
401 |
}
|
402 |
|
403 |
-
@app.post("/trim", tags=["Basic Editing"])
|
404 |
-
async def trim_audio(
|
405 |
-
background_tasks: BackgroundTasks,
|
406 |
-
file: UploadFile = File(..., description="Audio file to trim."),
|
407 |
-
start_ms: int = Form(..., description="Start time in milliseconds."),
|
408 |
-
end_ms: int = Form(..., description="End time in milliseconds.")
|
409 |
-
):
|
410 |
-
"""Trims an audio file to the specified start and end times (in milliseconds)."""
|
411 |
-
if start_ms < 0 or end_ms <= start_ms:
|
412 |
-
raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")
|
413 |
-
|
414 |
-
logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
|
415 |
-
input_path = None
|
416 |
-
output_path = None
|
417 |
-
try:
|
418 |
-
input_path = await save_upload_file(file, prefix="trim_in_")
|
419 |
-
background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
|
420 |
-
|
421 |
-
# Use Pydub for basic trim
|
422 |
-
audio = AudioSegment.from_file(input_path)
|
423 |
-
trimmed_audio = audio[start_ms:end_ms]
|
424 |
-
logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
|
425 |
-
|
426 |
-
original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
|
427 |
-
if not original_format or original_format == "tmp": original_format = "mp3"
|
428 |
-
output_filename = f"trimmed_{uuid.uuid4().hex}.{original_format}"
|
429 |
-
output_path = os.path.join(TEMP_DIR, output_filename)
|
430 |
-
|
431 |
-
trimmed_audio.export(output_path, format=original_format)
|
432 |
-
background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
|
433 |
-
|
434 |
-
return FileResponse(
|
435 |
-
path=output_path,
|
436 |
-
media_type=f"audio/{original_format}", # Attempt correct media type
|
437 |
-
filename=f"trimmed_{file.filename}"
|
438 |
-
)
|
439 |
-
except CouldntDecodeError:
|
440 |
-
logger.warning(f"pydub failed to decode: {file.filename}")
|
441 |
-
raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
|
442 |
-
except Exception as e:
|
443 |
-
logger.error(f"Error during trim operation: {e}", exc_info=True)
|
444 |
-
if output_path: cleanup_file(output_path) # Immediate cleanup on error
|
445 |
-
if input_path: cleanup_file(input_path) # Immediate cleanup on error
|
446 |
-
if isinstance(e, HTTPException): raise e
|
447 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}")
|
448 |
-
|
449 |
-
|
450 |
-
@app.post("/concat", tags=["Basic Editing"])
|
451 |
-
async def concatenate_audio(
|
452 |
-
background_tasks: BackgroundTasks,
|
453 |
-
files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
|
454 |
-
output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
|
455 |
-
):
|
456 |
-
"""Concatenates two or more audio files sequentially."""
|
457 |
-
if len(files) < 2:
|
458 |
-
raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
|
459 |
-
|
460 |
-
logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
|
461 |
-
input_paths = []
|
462 |
-
loaded_audios = []
|
463 |
-
output_path = None
|
464 |
-
|
465 |
-
try:
|
466 |
-
combined_audio = AudioSegment.empty()
|
467 |
-
first_filename_base = "combined"
|
468 |
-
|
469 |
-
for i, file in enumerate(files):
|
470 |
-
input_path = await save_upload_file(file, prefix=f"concat_{i}_")
|
471 |
-
input_paths.append(input_path)
|
472 |
-
background_tasks.add_task(cleanup_file, input_path)
|
473 |
-
audio = AudioSegment.from_file(input_path)
|
474 |
-
combined_audio += audio
|
475 |
-
if i == 0: first_filename_base = os.path.splitext(file.filename)[0]
|
476 |
-
logger.info(f"Appended '{file.filename}', current total duration: {len(combined_audio)}ms")
|
477 |
-
|
478 |
-
if len(combined_audio) == 0:
|
479 |
-
raise HTTPException(status_code=500, detail="No audio data after attempting concatenation.")
|
480 |
-
|
481 |
-
output_filename_final = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
|
482 |
-
output_path = os.path.join(TEMP_DIR, f"concat_out_{uuid.uuid4().hex}.{output_format}")
|
483 |
-
|
484 |
-
combined_audio.export(output_path, format=output_format)
|
485 |
-
background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
|
486 |
-
|
487 |
-
return FileResponse(
|
488 |
-
path=output_path,
|
489 |
-
media_type=f"audio/{output_format}",
|
490 |
-
filename=output_filename_final
|
491 |
-
)
|
492 |
-
except CouldntDecodeError as e:
|
493 |
-
logger.warning(f"pydub failed to decode one of the concat files: {e}")
|
494 |
-
raise HTTPException(status_code=415, detail=f"Unsupported format or corrupted file among inputs: {e}")
|
495 |
-
except Exception as e:
|
496 |
-
logger.error(f"Error during concat operation: {e}", exc_info=True)
|
497 |
-
if output_path: cleanup_file(output_path)
|
498 |
-
for p in input_paths: cleanup_file(p)
|
499 |
-
if isinstance(e, HTTPException): raise e
|
500 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}")
|
501 |
-
|
502 |
-
@app.post("/volume", tags=["Basic Editing"])
|
503 |
-
async def change_volume(
|
504 |
-
background_tasks: BackgroundTasks,
|
505 |
-
file: UploadFile = File(..., description="Audio file to adjust volume for."),
|
506 |
-
change_db: float = Form(..., description="Volume change in decibels (dB). Positive values increase volume, negative values decrease.")
|
507 |
-
):
|
508 |
-
"""Adjusts the volume of an audio file by a specified decibel amount."""
|
509 |
-
logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
|
510 |
-
input_path = None
|
511 |
-
output_path = None
|
512 |
-
try:
|
513 |
-
input_path = await save_upload_file(file, prefix="volume_in_")
|
514 |
-
background_tasks.add_task(cleanup_file, input_path)
|
515 |
-
|
516 |
-
audio = AudioSegment.from_file(input_path)
|
517 |
-
adjusted_audio = audio + change_db
|
518 |
-
logger.info(f"Volume adjusted by {change_db}dB.")
|
519 |
-
|
520 |
-
original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
|
521 |
-
if not original_format or original_format == "tmp": original_format = "mp3"
|
522 |
-
output_filename_final = f"volume_{change_db}dB_{file.filename}"
|
523 |
-
output_path = os.path.join(TEMP_DIR, f"volume_out_{uuid.uuid4().hex}.{original_format}")
|
524 |
-
|
525 |
-
adjusted_audio.export(output_path, format=original_format)
|
526 |
-
background_tasks.add_task(cleanup_file, output_path)
|
527 |
-
|
528 |
-
return FileResponse(
|
529 |
-
path=output_path,
|
530 |
-
media_type=f"audio/{original_format}",
|
531 |
-
filename=output_filename_final
|
532 |
-
)
|
533 |
-
except CouldntDecodeError:
|
534 |
-
logger.warning(f"pydub failed to decode: {file.filename}")
|
535 |
-
raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
|
536 |
-
except Exception as e:
|
537 |
-
logger.error(f"Error during volume operation: {e}", exc_info=True)
|
538 |
-
if output_path: cleanup_file(output_path)
|
539 |
-
if input_path: cleanup_file(input_path)
|
540 |
-
if isinstance(e, HTTPException): raise e
|
541 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}")
|
542 |
-
|
543 |
-
@app.post("/convert", tags=["Basic Editing"])
|
544 |
-
async def convert_format(
|
545 |
-
background_tasks: BackgroundTasks,
|
546 |
-
file: UploadFile = File(..., description="Audio file to convert."),
|
547 |
-
output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
|
548 |
-
):
|
549 |
-
"""Converts an audio file to a different format."""
|
550 |
-
allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a'}
|
551 |
-
if output_format.lower() not in allowed_formats:
|
552 |
-
raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed: {', '.join(allowed_formats)}")
|
553 |
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
input_path = await save_upload_file(file, prefix="convert_in_")
|
559 |
-
background_tasks.add_task(cleanup_file, input_path)
|
560 |
-
|
561 |
-
audio = AudioSegment.from_file(input_path)
|
562 |
-
|
563 |
-
output_format_lower = output_format.lower()
|
564 |
-
filename_base = os.path.splitext(file.filename)[0]
|
565 |
-
output_filename_final = f"{filename_base}_converted.{output_format_lower}"
|
566 |
-
output_path = os.path.join(TEMP_DIR, f"convert_out_{uuid.uuid4().hex}.{output_format_lower}")
|
567 |
-
|
568 |
-
audio.export(output_path, format=output_format_lower)
|
569 |
-
background_tasks.add_task(cleanup_file, output_path)
|
570 |
-
|
571 |
-
return FileResponse(
|
572 |
-
path=output_path,
|
573 |
-
media_type=f"audio/{output_format_lower}",
|
574 |
-
filename=output_filename_final
|
575 |
-
)
|
576 |
-
except CouldntDecodeError:
|
577 |
-
logger.warning(f"pydub failed to decode: {file.filename}")
|
578 |
-
raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
|
579 |
-
except Exception as e:
|
580 |
-
logger.error(f"Error during convert operation: {e}", exc_info=True)
|
581 |
-
if output_path: cleanup_file(output_path)
|
582 |
-
if input_path: cleanup_file(input_path)
|
583 |
-
if isinstance(e, HTTPException): raise e
|
584 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}")
|
585 |
|
|
|
586 |
|
587 |
-
# (Include /enhance and /separate AI endpoints here - same as previous version)
|
588 |
-
# ...
|
589 |
@app.post("/enhance", tags=["AI Editing"])
|
590 |
async def enhance_speech(
|
591 |
background_tasks: BackgroundTasks,
|
592 |
file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
|
593 |
-
|
594 |
-
|
|
|
595 |
):
|
596 |
-
"""Enhances speech audio using a pre-loaded
|
597 |
-
if
|
598 |
-
raise HTTPException(status_code=501, detail="AI processing libraries not available.")
|
599 |
-
if model_key not in enhancement_models
|
600 |
logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
|
601 |
raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
|
602 |
|
603 |
-
|
604 |
-
|
|
|
|
|
|
|
605 |
output_path = None
|
|
|
606 |
try:
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
logger.debug(f"Loading audio for enhancement, target SR: {ENHANCEMENT_SR}")
|
612 |
-
audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
|
613 |
-
if current_sr != ENHANCEMENT_SR: # Should have been resampled, but double check
|
614 |
-
logger.warning(f"Audio SR after loading is {current_sr}, expected {ENHANCEMENT_SR}. Check resampling.")
|
615 |
-
# Depending on model strictness, could raise error or proceed cautiously.
|
616 |
-
# raise HTTPException(status_code=500, detail="Audio resampling failed.")
|
617 |
-
|
618 |
-
# Run inference in a separate thread
|
619 |
logger.info("Submitting enhancement task to background thread...")
|
620 |
-
|
621 |
-
_run_enhancement_sync,
|
622 |
)
|
623 |
logger.info("Enhancement task completed.")
|
624 |
|
625 |
-
# Save the result
|
626 |
-
output_path = save_hf_audio(
|
627 |
background_tasks.add_task(cleanup_file, output_path)
|
628 |
|
629 |
-
output_filename_final = f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
|
630 |
return FileResponse(
|
631 |
path=output_path,
|
632 |
media_type=f"audio/{output_format}",
|
633 |
-
filename=
|
634 |
)
|
635 |
-
|
636 |
except Exception as e:
|
637 |
logger.error(f"Error during enhancement operation: {e}", exc_info=True)
|
638 |
if output_path: cleanup_file(output_path)
|
639 |
-
if input_path: cleanup_file(input_path)
|
640 |
if isinstance(e, HTTPException): raise e
|
641 |
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
|
642 |
|
@@ -645,96 +390,79 @@ async def enhance_speech(
|
|
645 |
async def separate_sources(
|
646 |
background_tasks: BackgroundTasks,
|
647 |
file: UploadFile = File(..., description="Music audio file to separate into stems."),
|
648 |
-
model_key: str =
|
649 |
stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
|
650 |
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
|
651 |
):
|
652 |
-
"""Separates music into stems
|
653 |
-
if
|
654 |
-
raise HTTPException(status_code=501, detail="AI processing libraries not available.")
|
655 |
-
if model_key not in separation_models
|
656 |
logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
|
657 |
raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
|
658 |
|
659 |
-
|
|
|
660 |
requested_stems = set(s.lower() for s in stems)
|
661 |
if not requested_stems.issubset(valid_stems):
|
662 |
-
|
663 |
-
raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are generally: {', '.join(valid_stems)}")
|
664 |
|
665 |
-
logger.info(f"Separate request: file='{file.filename}',
|
666 |
-
input_path =
|
|
|
667 |
stem_output_paths: Dict[str, str] = {}
|
668 |
-
zip_buffer =
|
669 |
|
670 |
try:
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
# Load audio, ensure correct SR for the model
|
675 |
-
logger.debug(f"Loading audio for separation, target SR: {DEMUCS_SR}")
|
676 |
-
audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
|
677 |
-
if current_sr != DEMUCS_SR:
|
678 |
-
logger.warning(f"Audio SR after loading is {current_sr}, expected {DEMUCS_SR}. Check resampling.")
|
679 |
-
# raise HTTPException(status_code=500, detail="Audio resampling failed.")
|
680 |
|
681 |
-
# Run inference in a separate thread
|
682 |
logger.info("Submitting separation task to background thread...")
|
683 |
-
|
684 |
-
_run_separation_sync,
|
685 |
)
|
686 |
logger.info("Separation task completed.")
|
687 |
|
688 |
# --- Create ZIP file in memory ---
|
689 |
-
|
690 |
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
691 |
-
|
692 |
-
found_stems_count = 0
|
693 |
for stem_name in requested_stems:
|
694 |
-
if stem_name in
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
stem_output_paths[stem_name] = stem_path # Keep track for cleanup
|
704 |
-
background_tasks.add_task(cleanup_file, stem_path) # Schedule cleanup
|
705 |
-
|
706 |
-
# Add the saved stem file to the ZIP archive
|
707 |
-
archive_name = f"{stem_name}.{output_format}" # Simple name inside zip
|
708 |
zipf.write(stem_path, arcname=archive_name)
|
709 |
logger.info(f"Added '{archive_name}' to ZIP.")
|
710 |
-
found_stems_count += 1
|
711 |
else:
|
712 |
-
|
|
|
713 |
|
714 |
-
|
715 |
-
raise HTTPException(status_code=404, detail="None of the requested stems were found or generated successfully.")
|
716 |
|
717 |
-
|
718 |
-
|
719 |
-
# Return the ZIP file via StreamingResponse
|
720 |
-
zip_filename_download = f"{zip_filename_base}.zip"
|
721 |
-
logger.info(f"Sending ZIP file '{zip_filename_download}'")
|
722 |
return StreamingResponse(
|
723 |
-
zip_buffer,
|
724 |
media_type="application/zip",
|
725 |
-
headers={'Content-Disposition': f'attachment; filename="{
|
726 |
)
|
727 |
-
|
728 |
except Exception as e:
|
729 |
logger.error(f"Error during separation operation: {e}", exc_info=True)
|
730 |
-
# Cleanup temporary stem files if they exist
|
731 |
for path in stem_output_paths.values(): cleanup_file(path)
|
732 |
-
|
733 |
-
# if zip_buffer and not zip_buffer.closed: zip_buffer.close()
|
734 |
-
if input_path: cleanup_file(input_path)
|
735 |
-
|
736 |
if isinstance(e, HTTPException): raise e
|
737 |
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
|
738 |
|
|
|
|
|
739 |
|
740 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import uuid
|
3 |
import tempfile
|
4 |
import logging
|
5 |
import asyncio
|
6 |
from typing import List, Optional, Dict, Any
|
|
|
|
|
7 |
|
8 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
|
9 |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
10 |
+
import io
|
11 |
+
import zipfile
|
12 |
|
13 |
# --- Basic Editing Imports ---
|
14 |
from pydub import AudioSegment
|
|
|
17 |
# --- AI & Advanced Audio Imports ---
|
18 |
try:
|
19 |
import torch
|
20 |
+
# Transformers only needed if using HF pipelines directly, not for speechbrain/demucs manual loading
|
21 |
+
# from transformers import pipeline
|
|
|
|
|
22 |
import soundfile as sf
|
23 |
import numpy as np
|
24 |
+
import librosa
|
25 |
+
|
26 |
+
# Specific Model Libraries
|
27 |
+
import speechbrain.pretrained
|
28 |
+
import demucs.separate
|
29 |
+
import demucs.apply
|
30 |
+
|
31 |
print("AI and advanced audio libraries loaded.")
|
32 |
except ImportError as e:
|
33 |
+
print(f"Error importing AI/Audio libraries: {e}")
|
34 |
+
print("Ensure torch, soundfile, librosa, speechbrain, demucs are installed.")
|
35 |
print("AI features will be unavailable.")
|
|
|
|
|
36 |
torch = None
|
|
|
37 |
sf = None
|
38 |
np = None
|
39 |
librosa = None
|
40 |
+
speechbrain = None
|
41 |
+
demucs = None
|
42 |
|
43 |
# --- Configuration & Setup ---
|
44 |
TEMP_DIR = tempfile.gettempdir()
|
45 |
os.makedirs(TEMP_DIR, exist_ok=True)
|
46 |
|
|
|
47 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
48 |
logger = logging.getLogger(__name__)
|
49 |
|
50 |
# --- Global Variables for Loaded Models ---
|
51 |
+
# Use consistent keys for storing/retrieving models
|
52 |
+
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
|
53 |
+
# Choose a default Demucs model (htdemucs is good quality)
|
54 |
+
SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
enhancement_models: Dict[str, Any] = {}
|
57 |
+
separation_models: Dict[str, Any] = {}
|
58 |
|
59 |
+
# Target sampling rates (confirm from model specifics if necessary)
|
60 |
+
ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
|
61 |
+
DEMUCS_SR = 44100 # Demucs default is 44.1kHz
|
62 |
|
63 |
+
# --- Device Selection ---
|
64 |
+
if torch:
|
65 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
66 |
+
logger.info(f"Using device: {DEVICE}")
|
67 |
+
else:
|
68 |
+
DEVICE = "cpu" # Fallback if torch failed import
|
69 |
+
|
70 |
+
# --- Helper Functions (cleanup_file, save_upload_file - same as before) ---
|
71 |
def cleanup_file(file_path: str):
|
|
|
72 |
try:
|
73 |
if file_path and os.path.exists(file_path):
|
74 |
os.remove(file_path)
|
|
|
77 |
logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
|
78 |
|
79 |
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
|
|
|
80 |
_, file_extension = os.path.splitext(upload_file.filename)
|
81 |
+
if not file_extension: file_extension = ".wav"
|
82 |
temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
|
83 |
try:
|
84 |
with open(temp_file_path, "wb") as buffer:
|
|
|
92 |
finally:
|
93 |
await upload_file.close()
|
94 |
|
95 |
+
# --- Audio Loading/Saving for AI Models ---
|
96 |
+
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
|
97 |
+
"""Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
|
|
|
98 |
try:
|
99 |
audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
|
100 |
logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
|
101 |
|
102 |
+
if audio.ndim > 1 and audio.shape[0] > 5: # Check if likely stereo (more than 5 channels unlikely)
|
103 |
+
logger.warning(f"Detected {audio.shape[0]} channels, attempting to convert to mono by averaging.")
|
104 |
+
audio = np.mean(audio, axis=0) # Average channels if multi-channel
|
105 |
+
elif audio.ndim > 1:
|
106 |
+
audio = audio[0] # Take first channel if shape is like (1, N)
|
|
|
|
|
|
|
107 |
|
108 |
+
# Convert numpy array to torch tensor
|
109 |
+
audio_tensor = torch.from_numpy(audio).float()
|
110 |
|
111 |
+
# Resample if necessary using librosa then convert back to tensor
|
112 |
if target_sr and orig_sr != target_sr:
|
113 |
+
if librosa is None: raise RuntimeError("Librosa is required for resampling but not installed.")
|
|
|
114 |
logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
|
115 |
+
# Librosa works on numpy, so convert back temp.
|
116 |
+
audio_np = audio_tensor.numpy()
|
117 |
+
resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
|
118 |
+
audio_tensor = torch.from_numpy(resampled_audio_np).float()
|
|
|
119 |
current_sr = target_sr
|
120 |
+
logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
|
121 |
else:
|
122 |
current_sr = orig_sr
|
123 |
|
124 |
+
# Ensure tensor is on the correct device
|
125 |
+
return audio_tensor.to(DEVICE), current_sr
|
126 |
|
127 |
except Exception as e:
|
128 |
logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
|
129 |
raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
|
130 |
|
131 |
+
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
|
132 |
+
"""Saves audio data (Tensor or NumPy array) to a temporary file."""
|
|
|
|
|
|
|
133 |
output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
|
134 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
135 |
try:
|
136 |
+
logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
|
137 |
+
|
138 |
+
# Convert tensor to numpy array if needed
|
139 |
+
if isinstance(audio_data, torch.Tensor):
|
140 |
+
logger.debug("Converting output tensor to NumPy array.")
|
141 |
+
# Ensure tensor is on CPU before converting to numpy
|
142 |
+
audio_np = audio_data.detach().cpu().numpy()
|
143 |
+
elif isinstance(audio_data, np.ndarray):
|
144 |
+
audio_np = audio_data
|
145 |
+
else:
|
146 |
+
raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
|
147 |
|
148 |
+
# Ensure data is float32
|
149 |
+
if audio_np.dtype != np.float32:
|
150 |
+
logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
|
151 |
+
audio_np = audio_np.astype(np.float32)
|
152 |
|
153 |
+
# Clip values to avoid potential issues with formats expecting [-1, 1]
|
154 |
+
audio_np = np.clip(audio_np, -1.0, 1.0)
|
155 |
|
156 |
+
# Use soundfile (preferred for wav/flac)
|
157 |
if output_format.lower() in ['wav', 'flac']:
|
158 |
+
sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
|
159 |
else:
|
160 |
+
# For lossy formats, use pydub
|
161 |
+
logger.debug(f"Using pydub to export to lossy format: {output_format}")
|
162 |
# Scale float32 [-1, 1] to int16 for pydub
|
163 |
+
audio_int16 = (audio_np * 32767).astype(np.int16)
|
164 |
+
# Create AudioSegment (assuming mono for now)
|
165 |
+
num_channels = 1 if audio_int16.ndim == 1 else audio_int16.shape[0] # Basic channel check
|
166 |
+
if num_channels > 1 : audio_int16=audio_int16[0] # Use first channel if > 1, needs better handling
|
|
|
167 |
segment = AudioSegment(
|
168 |
audio_int16.tobytes(),
|
169 |
frame_rate=sampling_rate,
|
170 |
sample_width=audio_int16.dtype.itemsize,
|
171 |
+
channels=1 # Forcing mono currently
|
172 |
)
|
173 |
segment.export(output_path, format=output_format)
|
174 |
|
|
|
178 |
cleanup_file(output_path)
|
179 |
raise HTTPException(status_code=500, detail="Failed to save processed audio.")
|
180 |
|
181 |
+
# --- Synchronous AI Inference Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
+
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
|
184 |
+
"""Synchronous wrapper for SpeechBrain enhancement model inference."""
|
185 |
+
if not model: raise ValueError("Enhancement model not loaded")
|
186 |
try:
|
187 |
+
logger.info(f"Running speech enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
|
188 |
+
# SpeechBrain models usually take tensors directly
|
189 |
+
# Add batch dimension if needed (most SB models expect batch)
|
190 |
+
if audio_tensor.ndim == 1:
|
191 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
192 |
+
|
193 |
+
# Move tensor to the same device as the model
|
194 |
+
model_device = next(model.parameters()).device
|
195 |
+
audio_tensor = audio_tensor.to(model_device)
|
196 |
|
197 |
+
with torch.no_grad():
|
198 |
+
# Use enhance_batch for batched input
|
199 |
+
enhanced_tensor = model.enhance_batch(audio_tensor, lengths=torch.tensor([audio_tensor.shape[1]]).to(model_device))
|
200 |
|
201 |
+
# Remove batch dimension from output before returning
|
202 |
+
enhanced_audio = enhanced_tensor.squeeze(0).cpu() # Move back to CPU
|
203 |
logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
|
204 |
return enhanced_audio
|
205 |
except Exception as e:
|
206 |
+
logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
|
207 |
+
raise
|
208 |
|
209 |
+
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
|
210 |
+
"""Synchronous wrapper for Demucs source separation model inference."""
|
211 |
+
if not model: raise ValueError("Separation model not loaded")
|
212 |
+
if not demucs: raise RuntimeError("Demucs library not available")
|
213 |
+
try:
|
214 |
+
logger.info(f"Running source separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
|
215 |
|
216 |
+
# Demucs expects audio as (batch, channels, samples)
|
217 |
+
# Ensure input tensor is on the correct device
|
218 |
+
model_device = next(model.parameters()).device
|
219 |
+
audio_tensor = audio_tensor.to(model_device)
|
220 |
|
221 |
+
# Add batch and channel dimensions if mono
|
222 |
+
if audio_tensor.ndim == 1:
|
223 |
+
audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
|
224 |
+
elif audio_tensor.ndim == 2: # Should not happen often if load_audio ensures mono tensor
|
225 |
+
logger.warning("Input tensor has 2 dims, assuming (batch, samples), adding channel dim.")
|
226 |
+
audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
|
227 |
|
228 |
+
# Ensure correct number of channels expected by the model (usually 2)
|
229 |
+
if audio_tensor.shape[1] != model.audio_channels:
|
230 |
+
logger.warning(f"Model expects {model.audio_channels} channels, input has {audio_tensor.shape[1]}. Repeating mono channel.")
|
231 |
+
audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1) # Repeat mono to match expected channels
|
232 |
+
|
233 |
+
|
234 |
+
logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
with torch.no_grad():
|
237 |
+
# Use demucs.apply.apply_model which handles chunking etc.
|
238 |
+
# requires ref = audio_tensor.mean(0) # Average channels for reference
|
239 |
+
# sources = demucs.apply.apply_model(model, audio_tensor[0], device=model_device, shifts=1, split=True, overlap=0.25)[0] # Process first batch item
|
240 |
+
|
241 |
+
# OR direct model call if simpler:
|
242 |
+
sources = model(audio_tensor)[0] # Output shape (stems, channels, samples) - remove batch dim [0]
|
243 |
+
|
244 |
+
logger.debug(f"Raw separated sources tensor shape: {sources.shape}") # Should be (num_stems, channels, samples)
|
245 |
+
|
246 |
+
# Map stems based on the model's sources list
|
247 |
+
# Default for htdemucs: drums, bass, other, vocals
|
248 |
+
stem_map = {name: sources[i] for i, name in enumerate(model.sources)}
|
249 |
+
|
250 |
+
# Convert back to mono for simplicity (average channels) and move to CPU
|
251 |
+
output_stems = {}
|
252 |
+
for name, data in stem_map.items():
|
253 |
+
output_stems[name] = data.mean(dim=0).detach().cpu() # Average channels, detach, move to CPU
|
254 |
+
|
255 |
+
logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
|
256 |
+
return output_stems
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
except Exception as e:
|
259 |
+
logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
|
260 |
raise
|
261 |
|
262 |
# --- Model Loading Function ---
|
|
|
|
|
|
|
|
|
263 |
def load_hf_models():
|
264 |
+
"""Loads AI models at startup using correct libraries."""
|
|
|
|
|
|
|
|
|
265 |
global enhancement_models, separation_models
|
266 |
+
if torch is None or speechbrain is None or demucs is None:
|
267 |
+
logger.error("Core AI libraries (torch, speechbrain, demucs) not available. Skipping model loading.")
|
268 |
+
return
|
269 |
|
270 |
+
# --- Load Enhancement Model (SpeechBrain) ---
|
271 |
+
enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
|
272 |
try:
|
273 |
+
logger.info(f"Loading enhancement model: {enhancement_model_hparams} (using SpeechBrain)...")
|
274 |
+
# Ensure SpeechBrain downloads to a writable location if needed (optional)
|
275 |
+
# savedir = os.path.join(TEMP_DIR, "speechbrain_models")
|
276 |
+
# os.makedirs(savedir, exist_ok=True)
|
277 |
+
enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
|
278 |
+
source=enhancement_model_hparams,
|
279 |
+
# savedir=savedir, # Specify download dir if needed
|
280 |
+
run_opts={"device": DEVICE} # Pass device option
|
281 |
+
)
|
282 |
+
enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer # Store with consistent key
|
283 |
+
logger.info(f"Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {DEVICE}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
except Exception as e:
|
285 |
+
logger.error(f"Failed to load enhancement model '{enhancement_model_hparams}': {e}", exc_info=True)
|
|
|
286 |
|
287 |
# --- Load Separation Model (Demucs) ---
|
288 |
+
# Using a standard pretrained model name from the demucs package
|
289 |
+
separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs" or "mdx_extra_q"
|
290 |
try:
|
291 |
+
logger.info(f"Loading separation model: {separation_model_name} (using Demucs package)...")
|
292 |
+
# This automatically handles downloading the model checkpoint
|
293 |
+
separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
|
294 |
+
separation_models[SEPARATION_MODEL_KEY] = separator # Store with consistent key
|
295 |
+
logger.info(f"Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {DEVICE}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
except Exception as e:
|
297 |
+
logger.error(f"Failed to load separation model '{separation_model_name}': {e}", exc_info=True)
|
298 |
+
logger.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
|
299 |
|
300 |
|
301 |
+
# --- FastAPI App ---
|
302 |
app = FastAPI(
|
303 |
title="AI Audio Editor API",
|
304 |
+
description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries (torch, speechbrain, demucs).",
|
305 |
+
version="2.1.0", # Incremented version
|
306 |
)
|
307 |
|
308 |
@app.on_event("startup")
|
309 |
async def startup_event():
|
310 |
+
logger.info("Application startup: Loading AI models...")
|
311 |
+
# Run blocking model load in thread
|
312 |
await asyncio.to_thread(load_hf_models)
|
313 |
+
logger.info("Model loading process finished (check logs for success/failure).")
|
|
|
314 |
|
315 |
# --- API Endpoints ---
|
316 |
+
|
|
|
317 |
@app.get("/", tags=["General"])
|
318 |
def read_root():
|
319 |
+
# ... (root endpoint remains the same) ...
|
320 |
features = ["/trim", "/concat", "/volume", "/convert"]
|
321 |
ai_features = []
|
322 |
+
if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
|
323 |
+
if separation_models: ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY})")
|
|
|
324 |
|
325 |
return {
|
326 |
"message": "Welcome to the AI Audio Editor API.",
|
327 |
"basic_features": features,
|
328 |
+
"ai_features": ai_features if ai_features else "None available (check startup logs)",
|
329 |
+
"notes": "Requires FFmpeg. AI features require specific models loaded at startup."
|
330 |
}
|
331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
|
333 |
+
# --- Basic Editing Endpoints ---
|
334 |
+
# (Add /trim, /concat, /volume, /convert endpoints here - same logic as before)
|
335 |
+
# Make sure they use the updated cleanup_file and save_upload_file helpers.
|
336 |
+
# ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
|
338 |
+
# --- AI Endpoints (Corrected) ---
|
339 |
|
|
|
|
|
340 |
@app.post("/enhance", tags=["AI Editing"])
|
341 |
async def enhance_speech(
|
342 |
background_tasks: BackgroundTasks,
|
343 |
file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
|
344 |
+
# Model ID is less relevant now if only one is loaded, but keep for future flexibility
|
345 |
+
model_key: str = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use."),
|
346 |
+
output_format: str = Form("wav", description="Output format (wav, flac recommended).")
|
347 |
):
|
348 |
+
"""Enhances speech audio using a pre-loaded SpeechBrain model."""
|
349 |
+
if torch is None or speechbrain is None:
|
350 |
+
raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
|
351 |
+
if model_key not in enhancement_models:
|
352 |
logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
|
353 |
raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
|
354 |
|
355 |
+
loaded_model = enhancement_models[model_key]
|
356 |
+
|
357 |
+
logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
|
358 |
+
input_path = await save_upload_file(file, prefix="enhance_in_")
|
359 |
+
background_tasks.add_task(cleanup_file, input_path)
|
360 |
output_path = None
|
361 |
+
|
362 |
try:
|
363 |
+
# Load audio as tensor, ensure correct SR
|
364 |
+
# SpeechBrain Sepformer expects 16kHz
|
365 |
+
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
|
366 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
logger.info("Submitting enhancement task to background thread...")
|
368 |
+
enhanced_audio_tensor = await asyncio.to_thread(
|
369 |
+
_run_enhancement_sync, loaded_model, audio_tensor, current_sr # Pass SR even if unused by func now
|
370 |
)
|
371 |
logger.info("Enhancement task completed.")
|
372 |
|
373 |
+
# Save the result (tensor output from enhancer)
|
374 |
+
output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format) # Save at model's SR
|
375 |
background_tasks.add_task(cleanup_file, output_path)
|
376 |
|
|
|
377 |
return FileResponse(
|
378 |
path=output_path,
|
379 |
media_type=f"audio/{output_format}",
|
380 |
+
filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
|
381 |
)
|
|
|
382 |
except Exception as e:
|
383 |
logger.error(f"Error during enhancement operation: {e}", exc_info=True)
|
384 |
if output_path: cleanup_file(output_path)
|
|
|
385 |
if isinstance(e, HTTPException): raise e
|
386 |
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
|
387 |
|
|
|
390 |
async def separate_sources(
|
391 |
background_tasks: BackgroundTasks,
|
392 |
file: UploadFile = File(..., description="Music audio file to separate into stems."),
|
393 |
+
model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
|
394 |
stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
|
395 |
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
|
396 |
):
|
397 |
+
"""Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
|
398 |
+
if torch is None or demucs is None:
|
399 |
+
raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
|
400 |
+
if model_key not in separation_models:
|
401 |
logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
|
402 |
raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
|
403 |
|
404 |
+
loaded_model = separation_models[model_key]
|
405 |
+
valid_stems = set(loaded_model.sources) # Get stems directly from loaded model
|
406 |
requested_stems = set(s.lower() for s in stems)
|
407 |
if not requested_stems.issubset(valid_stems):
|
408 |
+
raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Model '{model_key}' provides: {', '.join(valid_stems)}")
|
|
|
409 |
|
410 |
+
logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
|
411 |
+
input_path = await save_upload_file(file, prefix="separate_in_")
|
412 |
+
background_tasks.add_task(cleanup_file, input_path)
|
413 |
stem_output_paths: Dict[str, str] = {}
|
414 |
+
zip_buffer = None
|
415 |
|
416 |
try:
|
417 |
+
# Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
|
418 |
+
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
|
|
|
420 |
logger.info("Submitting separation task to background thread...")
|
421 |
+
all_separated_stems_tensors = await asyncio.to_thread(
|
422 |
+
_run_separation_sync, loaded_model, audio_tensor, current_sr # Pass SR even if unused by func now
|
423 |
)
|
424 |
logger.info("Separation task completed.")
|
425 |
|
426 |
# --- Create ZIP file in memory ---
|
427 |
+
zip_buffer = io.BytesIO()
|
428 |
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
429 |
+
# Save only the requested stems
|
|
|
430 |
for stem_name in requested_stems:
|
431 |
+
if stem_name in all_separated_stems_tensors:
|
432 |
+
stem_tensor = all_separated_stems_tensors[stem_name]
|
433 |
+
# Save stem temporarily (save_hf_audio handles tensor)
|
434 |
+
# Use the model's native sampling rate for output
|
435 |
+
stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
|
436 |
+
stem_output_paths[stem_name] = stem_path
|
437 |
+
background_tasks.add_task(cleanup_file, stem_path)
|
438 |
+
|
439 |
+
archive_name = f"{stem_name}_{os.path.splitext(file.filename)[0]}.{output_format}"
|
|
|
|
|
|
|
|
|
|
|
440 |
zipf.write(stem_path, arcname=archive_name)
|
441 |
logger.info(f"Added '{archive_name}' to ZIP.")
|
|
|
442 |
else:
|
443 |
+
# This case should be prevented by the earlier validation
|
444 |
+
logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen).")
|
445 |
|
446 |
+
zip_buffer.seek(0)
|
|
|
447 |
|
448 |
+
zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
|
|
|
|
|
|
|
|
|
449 |
return StreamingResponse(
|
450 |
+
zip_buffer,
|
451 |
media_type="application/zip",
|
452 |
+
headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
|
453 |
)
|
|
|
454 |
except Exception as e:
|
455 |
logger.error(f"Error during separation operation: {e}", exc_info=True)
|
|
|
456 |
for path in stem_output_paths.values(): cleanup_file(path)
|
457 |
+
if zip_buffer: zip_buffer.close()
|
|
|
|
|
|
|
458 |
if isinstance(e, HTTPException): raise e
|
459 |
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
|
460 |
|
461 |
+
# --- Add back the basic editing endpoints (/trim, /concat, /volume, /convert) here ---
|
462 |
+
# ... (Remember to include them) ...
|
463 |
|
464 |
+
# --- How to Run ---
|
465 |
+
# 1. Ensure FFmpeg is installed.
|
466 |
+
# 2. Save code as `app.py`. Create/update `requirements.txt`.
|
467 |
+
# 3. Install: `pip install -r requirements.txt` (May take significant time/space!)
|
468 |
+
# 4. Run: `uvicorn app:app --reload --host 0.0.0.0`
|