Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import tempfile
|
|
5 |
import logging
|
6 |
import asyncio
|
7 |
from typing import List, Optional, Dict, Any
|
|
|
8 |
|
9 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
|
10 |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
@@ -16,344 +17,292 @@ from pydub import AudioSegment
|
|
16 |
from pydub.exceptions import CouldntDecodeError
|
17 |
|
18 |
# --- AI & Advanced Audio Imports ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
try:
|
|
|
20 |
import torch
|
21 |
-
|
22 |
-
# from transformers import pipeline
|
23 |
import soundfile as sf
|
|
|
24 |
import numpy as np
|
|
|
25 |
import librosa
|
26 |
-
|
27 |
-
# Specific Model Libraries
|
28 |
import speechbrain.pretrained
|
|
|
29 |
import demucs.separate
|
30 |
import demucs.apply
|
31 |
-
|
32 |
-
|
33 |
except ImportError as e:
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
torch = None
|
38 |
sf = None
|
39 |
np = None
|
40 |
librosa = None
|
41 |
speechbrain = None
|
42 |
demucs = None
|
|
|
43 |
|
44 |
# --- Configuration & Setup ---
|
45 |
TEMP_DIR = tempfile.gettempdir()
|
46 |
os.makedirs(TEMP_DIR, exist_ok=True)
|
47 |
|
48 |
-
logging
|
49 |
-
logger = logging.getLogger(__name__)
|
50 |
|
51 |
# --- Global Variables for Loaded Models ---
|
52 |
-
# Use consistent keys for storing/retrieving models
|
53 |
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
|
54 |
-
|
55 |
-
SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
|
56 |
|
57 |
enhancement_models: Dict[str, Any] = {}
|
58 |
separation_models: Dict[str, Any] = {}
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
DEMUCS_SR = 44100 # Demucs default is 44.1kHz
|
63 |
|
64 |
# --- Device Selection ---
|
65 |
if torch:
|
66 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
67 |
-
|
68 |
else:
|
69 |
-
DEVICE = "cpu"
|
|
|
70 |
|
71 |
-
# --- Helper Functions ---
|
72 |
|
|
|
73 |
def cleanup_file(file_path: str):
|
74 |
"""Safely remove a file."""
|
75 |
try:
|
76 |
if file_path and os.path.exists(file_path):
|
77 |
os.remove(file_path)
|
78 |
-
logger.info(f"Cleaned up temporary file: {file_path}")
|
79 |
except Exception as e:
|
80 |
logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
|
81 |
|
82 |
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
|
83 |
"""Saves an uploaded file to a temporary location and returns the path."""
|
84 |
_, file_extension = os.path.splitext(upload_file.filename)
|
85 |
-
# Default to .wav if no extension, as it's widely compatible for loading
|
86 |
if not file_extension: file_extension = ".wav"
|
87 |
temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
|
88 |
try:
|
89 |
with open(temp_file_path, "wb") as buffer:
|
90 |
-
|
91 |
-
while content := await upload_file.read(1024 * 1024): # 1MB chunks
|
92 |
-
buffer.write(content)
|
93 |
logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
|
94 |
return temp_file_path
|
95 |
except Exception as e:
|
96 |
logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
|
97 |
-
cleanup_file(temp_file_path)
|
98 |
raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
|
99 |
finally:
|
100 |
-
await upload_file.close()
|
101 |
|
102 |
-
# --- Audio Loading/Saving
|
103 |
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
|
104 |
"""Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
|
|
|
105 |
try:
|
106 |
audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
|
107 |
-
logger.
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
audio = audio[0
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
audio = np.mean(audio, axis=1)
|
117 |
-
|
118 |
-
# Convert numpy array to torch tensor
|
119 |
-
audio_tensor = torch.from_numpy(audio).float()
|
120 |
|
121 |
-
|
122 |
if target_sr and orig_sr != target_sr:
|
123 |
-
if librosa is None: raise RuntimeError("Librosa
|
124 |
-
logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
|
125 |
-
# Librosa works on numpy
|
126 |
audio_np = audio_tensor.numpy()
|
127 |
resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
|
128 |
audio_tensor = torch.from_numpy(resampled_audio_np).float()
|
129 |
current_sr = target_sr
|
130 |
-
logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
|
131 |
else:
|
132 |
current_sr = orig_sr
|
133 |
-
|
134 |
-
# Ensure tensor is on the correct device
|
135 |
return audio_tensor.to(DEVICE), current_sr
|
136 |
-
|
137 |
except Exception as e:
|
138 |
logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
|
139 |
-
# Clean up the potentially corrupted saved file if loading failed
|
140 |
cleanup_file(file_path)
|
141 |
-
raise HTTPException(status_code=415, detail=f"Could not load
|
|
|
142 |
|
143 |
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
|
144 |
"""Saves audio data (Tensor or NumPy array) to a temporary file."""
|
|
|
145 |
output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
|
146 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
147 |
try:
|
148 |
-
logger.
|
149 |
-
|
150 |
-
# Convert tensor to numpy array if needed
|
151 |
if isinstance(audio_data, torch.Tensor):
|
152 |
-
logger.debug("Converting output tensor to NumPy array.")
|
153 |
-
# Ensure tensor is on CPU before converting to numpy
|
154 |
audio_np = audio_data.detach().cpu().numpy()
|
155 |
elif isinstance(audio_data, np.ndarray):
|
156 |
audio_np = audio_data
|
157 |
else:
|
158 |
-
raise TypeError(f"Unsupported audio data type
|
159 |
|
160 |
-
|
161 |
-
if audio_np.dtype != np.float32:
|
162 |
-
logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
|
163 |
-
audio_np = audio_np.astype(np.float32)
|
164 |
-
|
165 |
-
# Clip values to avoid potential issues with formats expecting [-1, 1]
|
166 |
audio_np = np.clip(audio_np, -1.0, 1.0)
|
167 |
|
168 |
-
# Use soundfile (preferred for wav/flac)
|
169 |
if output_format.lower() in ['wav', 'flac']:
|
170 |
sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
|
171 |
else:
|
172 |
-
|
173 |
-
|
174 |
-
# Scale float32 [-1, 1] to int16 for pydub
|
175 |
-
# Ensure audio_np is 1D (mono) before scaling and converting
|
176 |
-
if audio_np.ndim > 1:
|
177 |
-
logger.warning(f"Audio data has {audio_np.ndim} dimensions, taking first dimension for pydub export.")
|
178 |
-
audio_np_mono = audio_np[0] if audio_np.shape[0] < audio_np.shape[1] else audio_np[:, 0] # Basic mono conversion attempt
|
179 |
-
else:
|
180 |
-
audio_np_mono = audio_np
|
181 |
-
|
182 |
audio_int16 = (audio_np_mono * 32767).astype(np.int16)
|
183 |
-
segment = AudioSegment(
|
184 |
-
audio_int16.tobytes(),
|
185 |
-
frame_rate=sampling_rate,
|
186 |
-
sample_width=audio_int16.dtype.itemsize,
|
187 |
-
channels=1 # Assuming mono
|
188 |
-
)
|
189 |
segment.export(output_path, format=output_format)
|
190 |
-
|
191 |
return output_path
|
192 |
except Exception as e:
|
193 |
logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
|
194 |
-
cleanup_file(output_path)
|
195 |
raise HTTPException(status_code=500, detail="Failed to save processed audio.")
|
196 |
|
197 |
-
|
198 |
-
# --- Pydub Loading (for basic edits) ---
|
199 |
def load_audio_pydub(file_path: str) -> AudioSegment:
|
200 |
-
|
201 |
try:
|
202 |
audio = AudioSegment.from_file(file_path)
|
203 |
-
logger.info(f"Loaded audio using pydub from: {file_path}")
|
204 |
return audio
|
205 |
-
except CouldntDecodeError:
|
206 |
-
|
207 |
-
raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
|
208 |
-
except FileNotFoundError:
|
209 |
-
logger.error(f"Audio file not found after saving (pydub): {file_path}")
|
210 |
-
raise HTTPException(status_code=500, detail="Internal error: Audio file disappeared.")
|
211 |
-
except Exception as e:
|
212 |
-
logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
|
213 |
-
raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")
|
214 |
|
215 |
def export_audio_pydub(audio: AudioSegment, format: str) -> str:
|
216 |
-
|
217 |
output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
|
218 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
219 |
try:
|
220 |
-
logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
|
221 |
audio.export(output_path, format=format.lower())
|
222 |
return output_path
|
223 |
-
except Exception as e:
|
224 |
-
logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
|
225 |
-
cleanup_file(output_path) # Cleanup if export failed
|
226 |
-
raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
|
227 |
|
228 |
|
229 |
-
# --- Synchronous AI Inference Functions ---
|
230 |
-
|
231 |
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
|
232 |
-
|
233 |
if not model: raise ValueError("Enhancement model not loaded")
|
234 |
try:
|
235 |
-
logger.info(f"Running
|
236 |
-
|
237 |
-
if audio_tensor.
|
238 |
-
|
239 |
-
|
240 |
-
# Move tensor to the same device as the model
|
241 |
-
model_device = next(model.parameters()).device # Check model's current device
|
242 |
-
if audio_tensor.device != model_device:
|
243 |
-
audio_tensor = audio_tensor.to(model_device)
|
244 |
-
|
245 |
with torch.no_grad():
|
246 |
enhanced_tensor = model.enhance_batch(audio_tensor, lengths=torch.tensor([audio_tensor.shape[1]]).to(model_device))
|
247 |
-
|
248 |
-
# Remove batch dimension from output before returning, move back to CPU
|
249 |
enhanced_audio = enhanced_tensor.squeeze(0).cpu()
|
250 |
logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
|
251 |
return enhanced_audio
|
252 |
-
except Exception as e:
|
253 |
-
logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
|
254 |
-
raise
|
255 |
|
256 |
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
|
257 |
-
|
258 |
if not model: raise ValueError("Separation model not loaded")
|
259 |
-
if not demucs: raise RuntimeError("Demucs library
|
260 |
try:
|
261 |
-
logger.info(f"Running
|
262 |
-
|
263 |
-
# Move tensor to the same device as the model
|
264 |
model_device = next(model.parameters()).device
|
265 |
-
if audio_tensor.device != model_device:
|
266 |
-
|
267 |
-
|
268 |
-
# Add batch and channel dimensions if mono (expects batch, channels, samples)
|
269 |
-
if audio_tensor.ndim == 1:
|
270 |
-
audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
|
271 |
-
elif audio_tensor.ndim == 2: # Should be rare if loader works
|
272 |
-
audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
|
273 |
-
|
274 |
-
# Repeat channel if model expects stereo but input is mono
|
275 |
if audio_tensor.shape[1] != model.audio_channels:
|
276 |
-
if audio_tensor.shape[1] == 1:
|
277 |
-
|
278 |
-
audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
|
279 |
-
else:
|
280 |
-
# Cannot automatically handle other channel mismatches
|
281 |
-
raise ValueError(f"Input audio has {audio_tensor.shape[1]} channels, but Demucs model expects {model.audio_channels}.")
|
282 |
-
|
283 |
-
logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
|
284 |
-
|
285 |
with torch.no_grad():
|
286 |
-
|
287 |
-
# apply_model expects a tensor of shape (channels, samples)
|
288 |
-
# We process one batch item at a time if needed, but typically process the whole file
|
289 |
-
audio_to_process = audio_tensor.squeeze(0) # Remove batch dim -> (channels, samples)
|
290 |
out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25)
|
291 |
-
# Output shape (stems, channels, samples)
|
292 |
-
|
293 |
-
logger.debug(f"Raw separated sources tensor shape: {out.shape}")
|
294 |
-
|
295 |
-
# Map stems based on the model's sources list
|
296 |
stem_map = {name: out[i] for i, name in enumerate(model.sources)}
|
297 |
-
|
298 |
-
|
299 |
-
output_stems = {}
|
300 |
-
for name, data in stem_map.items():
|
301 |
-
# Average channels, detach, move to CPU
|
302 |
-
output_stems[name] = data.mean(dim=0).detach().cpu()
|
303 |
-
|
304 |
-
logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
|
305 |
return output_stems
|
|
|
306 |
|
307 |
-
except Exception as e:
|
308 |
-
logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
|
309 |
-
raise
|
310 |
|
311 |
-
# --- Model Loading Function ---
|
312 |
def load_hf_models():
|
313 |
"""Loads AI models at startup using correct libraries."""
|
|
|
|
|
|
|
|
|
314 |
global enhancement_models, separation_models
|
315 |
-
if
|
316 |
-
|
317 |
return
|
318 |
|
319 |
-
# --- Load Enhancement Model
|
320 |
enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
|
|
|
321 |
try:
|
322 |
-
|
|
|
323 |
enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
|
324 |
source=enhancement_model_hparams,
|
325 |
run_opts={"device": DEVICE}
|
326 |
)
|
|
|
|
|
327 |
enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
|
328 |
-
|
329 |
except Exception as e:
|
330 |
-
|
|
|
|
|
|
|
331 |
|
332 |
-
# --- Load Separation Model
|
333 |
separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs"
|
|
|
334 |
try:
|
335 |
-
|
|
|
336 |
separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
|
|
|
337 |
separation_models[SEPARATION_MODEL_KEY] = separator
|
338 |
-
|
339 |
-
|
340 |
except Exception as e:
|
341 |
-
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
|
345 |
# --- FastAPI App ---
|
346 |
app = FastAPI(
|
347 |
title="AI Audio Editor API",
|
348 |
-
description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries
|
349 |
-
version="2.1.
|
350 |
)
|
351 |
|
352 |
@app.on_event("startup")
|
353 |
async def startup_event():
|
354 |
-
logger
|
355 |
-
|
356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
357 |
|
358 |
# --- API Endpoints ---
|
359 |
|
@@ -362,185 +311,98 @@ def read_root():
|
|
362 |
"""Root endpoint providing a welcome message and available features."""
|
363 |
features = ["/trim", "/concat", "/volume", "/convert"]
|
364 |
ai_features = []
|
|
|
365 |
if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
|
366 |
-
if separation_models:
|
|
|
|
|
|
|
367 |
|
368 |
return {
|
369 |
"message": "Welcome to the AI Audio Editor API.",
|
|
|
|
|
|
|
370 |
"basic_features": features,
|
371 |
"ai_features": ai_features if ai_features else "None available (check startup logs)",
|
372 |
-
"notes": "Requires FFmpeg. AI features require
|
373 |
}
|
374 |
|
375 |
-
# --- Basic Editing Endpoints ---
|
376 |
|
|
|
|
|
377 |
@app.post("/trim", tags=["Basic Editing"])
|
378 |
-
async def trim_audio(
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
end_ms: int = Form(..., description="End time in milliseconds.")
|
383 |
-
):
|
384 |
-
"""Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub."""
|
385 |
-
if start_ms < 0 or end_ms <= start_ms:
|
386 |
-
raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")
|
387 |
-
|
388 |
-
logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
|
389 |
-
input_path = await save_upload_file(file, prefix="trim_in_")
|
390 |
-
background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
|
391 |
-
output_path = None
|
392 |
-
|
393 |
try:
|
394 |
audio = load_audio_pydub(input_path)
|
395 |
trimmed_audio = audio[start_ms:end_ms]
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
output_path = export_audio_pydub(trimmed_audio, original_format)
|
402 |
-
background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
|
403 |
-
|
404 |
-
output_filename=f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"
|
405 |
-
|
406 |
-
return FileResponse(
|
407 |
-
path=output_path,
|
408 |
-
media_type=f"audio/{original_format}",
|
409 |
-
filename=output_filename
|
410 |
-
)
|
411 |
except Exception as e:
|
412 |
-
logger.error(f"Error during trim operation: {e}", exc_info=True)
|
413 |
if output_path: cleanup_file(output_path)
|
414 |
-
if isinstance(e, HTTPException): raise e
|
415 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}")
|
416 |
|
417 |
@app.post("/concat", tags=["Basic Editing"])
|
418 |
-
async def concatenate_audio(
|
419 |
-
|
420 |
-
|
421 |
-
output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
|
422 |
-
):
|
423 |
-
"""Concatenates two or more audio files sequentially using Pydub."""
|
424 |
-
if len(files) < 2:
|
425 |
-
raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
|
426 |
-
|
427 |
-
logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
|
428 |
-
input_paths = []
|
429 |
-
loaded_audios = []
|
430 |
-
output_path = None
|
431 |
-
|
432 |
try:
|
|
|
433 |
for file in files:
|
434 |
-
|
435 |
-
input_paths.append(
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
if not loaded_audios: raise HTTPException(status_code=500, detail="No audio segments were loaded successfully.")
|
441 |
-
|
442 |
-
combined_audio = loaded_audios[0]
|
443 |
-
logger.info(f"Starting concatenation with first segment ({len(combined_audio)}ms)")
|
444 |
-
for i in range(1, len(loaded_audios)):
|
445 |
-
logger.info(f"Adding segment {i+1} ({len(loaded_audios[i])}ms)")
|
446 |
-
combined_audio += loaded_audios[i]
|
447 |
-
|
448 |
-
logger.info(f"Concatenated audio length: {len(combined_audio)}ms")
|
449 |
-
|
450 |
-
output_path = export_audio_pydub(combined_audio, output_format)
|
451 |
background_tasks.add_task(cleanup_file, output_path)
|
452 |
-
|
453 |
-
|
454 |
-
output_filename = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
|
455 |
-
|
456 |
-
return FileResponse(
|
457 |
-
path=output_path,
|
458 |
-
media_type=f"audio/{output_format}",
|
459 |
-
filename=output_filename
|
460 |
-
)
|
461 |
except Exception as e:
|
462 |
-
|
463 |
-
# Cleanup intermediate files if error occurs
|
464 |
-
for path in input_paths: cleanup_file(path)
|
465 |
if output_path: cleanup_file(output_path)
|
466 |
-
if isinstance(e, HTTPException): raise e
|
467 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}")
|
468 |
|
469 |
@app.post("/volume", tags=["Basic Editing"])
|
470 |
-
async def change_volume(
|
471 |
-
|
472 |
-
|
473 |
-
change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
|
474 |
-
):
|
475 |
-
"""Adjusts the volume of an audio file by a specified decibel amount using Pydub."""
|
476 |
-
logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
|
477 |
-
input_path = await save_upload_file(file, prefix="volume_in_")
|
478 |
-
background_tasks.add_task(cleanup_file, input_path)
|
479 |
-
output_path = None
|
480 |
-
|
481 |
try:
|
482 |
audio = load_audio_pydub(input_path)
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
|
487 |
-
if not original_format or original_format == "tmp": original_format = "mp3"
|
488 |
-
|
489 |
-
output_path = export_audio_pydub(adjusted_audio, original_format)
|
490 |
background_tasks.add_task(cleanup_file, output_path)
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
return FileResponse(
|
495 |
-
path=output_path,
|
496 |
-
media_type=f"audio/{original_format}",
|
497 |
-
filename=output_filename
|
498 |
-
)
|
499 |
except Exception as e:
|
500 |
-
logger.error(f"Error during volume operation: {e}", exc_info=True)
|
501 |
if output_path: cleanup_file(output_path)
|
502 |
-
if isinstance(e, HTTPException): raise e
|
503 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}")
|
504 |
|
505 |
@app.post("/convert", tags=["Basic Editing"])
|
506 |
-
async def convert_format(
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
)
|
511 |
-
"""Converts an audio file to a different format using Pydub."""
|
512 |
-
allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus'} # Common formats
|
513 |
-
if output_format.lower() not in allowed_formats:
|
514 |
-
raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed formats: {', '.join(allowed_formats)}")
|
515 |
-
|
516 |
-
logger.info(f"Convert request: file='{file.filename}', output_format='{output_format}'")
|
517 |
-
input_path = await save_upload_file(file, prefix="convert_in_")
|
518 |
-
background_tasks.add_task(cleanup_file, input_path)
|
519 |
-
output_path = None
|
520 |
-
|
521 |
try:
|
522 |
-
# Load using pydub, which handles many input formats
|
523 |
audio = load_audio_pydub(input_path)
|
524 |
-
|
525 |
output_path = export_audio_pydub(audio, output_format.lower())
|
526 |
background_tasks.add_task(cleanup_file, output_path)
|
527 |
-
|
528 |
-
|
529 |
-
output_filename = f"{filename_base}_converted.{output_format.lower()}"
|
530 |
-
|
531 |
-
return FileResponse(
|
532 |
-
path=output_path,
|
533 |
-
media_type=f"audio/{output_format.lower()}", # Media type might need refinement for opus/aac/m4a
|
534 |
-
filename=output_filename
|
535 |
-
)
|
536 |
except Exception as e:
|
537 |
-
logger.error(f"Error during convert operation: {e}", exc_info=True)
|
538 |
if output_path: cleanup_file(output_path)
|
539 |
-
if isinstance(e, HTTPException): raise e
|
540 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}")
|
541 |
|
542 |
|
543 |
-
# --- AI Endpoints (
|
544 |
|
545 |
@app.post("/enhance", tags=["AI Editing"])
|
546 |
async def enhance_speech(
|
@@ -550,44 +412,31 @@ async def enhance_speech(
|
|
550 |
output_format: str = Form("wav", description="Output format (wav, flac recommended).")
|
551 |
):
|
552 |
"""Enhances speech audio using a pre-loaded SpeechBrain model."""
|
553 |
-
if
|
554 |
-
raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
|
555 |
if model_key not in enhancement_models:
|
556 |
logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
|
557 |
-
raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
|
558 |
|
559 |
loaded_model = enhancement_models[model_key]
|
560 |
-
|
561 |
logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
|
562 |
input_path = await save_upload_file(file, prefix="enhance_in_")
|
563 |
background_tasks.add_task(cleanup_file, input_path)
|
564 |
output_path = None
|
565 |
-
|
566 |
try:
|
567 |
-
# Load audio as tensor, ensure correct SR (16kHz)
|
568 |
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
|
569 |
-
|
570 |
logger.info("Submitting enhancement task to background thread...")
|
571 |
enhanced_audio_tensor = await asyncio.to_thread(
|
572 |
_run_enhancement_sync, loaded_model, audio_tensor, current_sr
|
573 |
)
|
574 |
logger.info("Enhancement task completed.")
|
575 |
-
|
576 |
-
# Save the result (tensor output from enhancer at 16kHz)
|
577 |
output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
|
578 |
background_tasks.add_task(cleanup_file, output_path)
|
579 |
-
|
580 |
output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
|
581 |
-
return FileResponse(
|
582 |
-
path=output_path,
|
583 |
-
media_type=f"audio/{output_format}",
|
584 |
-
filename=output_filename
|
585 |
-
)
|
586 |
except Exception as e:
|
587 |
logger.error(f"Error during enhancement operation: {e}", exc_info=True)
|
588 |
-
if output_path: cleanup_file(output_path)
|
589 |
-
if isinstance(e, HTTPException): raise e
|
590 |
-
else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
|
591 |
|
592 |
|
593 |
@app.post("/separate", tags=["AI Editing"])
|
@@ -599,55 +448,43 @@ async def separate_sources(
|
|
599 |
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
|
600 |
):
|
601 |
"""Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
|
602 |
-
if
|
603 |
-
raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
|
604 |
if model_key not in separation_models:
|
605 |
logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
|
606 |
-
raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
|
607 |
|
608 |
loaded_model = separation_models[model_key]
|
609 |
-
valid_stems = set(loaded_model.sources)
|
610 |
requested_stems = set(s.lower() for s in stems)
|
611 |
if not requested_stems.issubset(valid_stems):
|
612 |
-
raise HTTPException(
|
613 |
|
614 |
logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
|
615 |
input_path = await save_upload_file(file, prefix="separate_in_")
|
616 |
background_tasks.add_task(cleanup_file, input_path)
|
617 |
stem_output_paths: Dict[str, str] = {}
|
618 |
-
zip_buffer = None
|
619 |
|
620 |
try:
|
621 |
-
# Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
|
622 |
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
|
623 |
-
|
624 |
logger.info("Submitting separation task to background thread...")
|
625 |
all_separated_stems_tensors = await asyncio.to_thread(
|
626 |
_run_separation_sync, loaded_model, audio_tensor, current_sr
|
627 |
)
|
628 |
logger.info("Separation task completed.")
|
629 |
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
background_tasks.add_task(cleanup_file, stem_path)
|
643 |
-
|
644 |
-
# Use a simpler archive name within the zip
|
645 |
-
archive_name = f"{stem_name}.{output_format}"
|
646 |
-
zipf.write(stem_path, arcname=archive_name)
|
647 |
-
logger.info(f"Added '{archive_name}' to ZIP.")
|
648 |
-
else:
|
649 |
-
logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen here due to validation).")
|
650 |
-
|
651 |
zip_buffer.seek(0)
|
652 |
|
653 |
zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
|
@@ -658,23 +495,11 @@ async def separate_sources(
|
|
658 |
)
|
659 |
except Exception as e:
|
660 |
logger.error(f"Error during separation operation: {e}", exc_info=True)
|
661 |
-
#
|
|
|
|
|
662 |
for path in stem_output_paths.values(): cleanup_file(path)
|
663 |
-
if zip_buffer: zip_buffer.close() # Ensure buffer is closed on error
|
664 |
if isinstance(e, HTTPException): raise e
|
665 |
-
else: raise HTTPException(
|
666 |
-
|
667 |
-
|
668 |
-
# --- How to Run ---
|
669 |
-
# 1. Ensure FFmpeg is installed and accessible in your PATH.
|
670 |
-
# 2. Save this code as `app.py`.
|
671 |
-
# 3. Create `requirements.txt` (as shown in previous responses, including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs).
|
672 |
-
# 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!).
|
673 |
-
# 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (Using --host 0.0.0.0 and port 7860 common for HF Spaces)
|
674 |
-
# Remove --reload for production/stable deployment.
|
675 |
-
#
|
676 |
-
# --- WARNING ---
|
677 |
-
# - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
|
678 |
-
# - The first run will download models, which can take a long time and lots of disk space.
|
679 |
-
# - Ensure the specific model IDs/names used (e.g., "htdemucs") are correct and compatible.
|
680 |
-
# - Monitor startup logs carefully for model loading success or failure.
|
|
|
5 |
import logging
|
6 |
import asyncio
|
7 |
from typing import List, Optional, Dict, Any
|
8 |
+
import traceback # For detailed error logging
|
9 |
|
10 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
|
11 |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
|
|
17 |
from pydub.exceptions import CouldntDecodeError
|
18 |
|
19 |
# --- AI & Advanced Audio Imports ---
|
20 |
+
# Add extra logging around imports
|
21 |
+
logger_init = logging.getLogger("AppInit")
|
22 |
+
logger_init.setLevel(logging.INFO)
|
23 |
+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
24 |
+
# Create console handler and set level to info
|
25 |
+
ch = logging.StreamHandler()
|
26 |
+
ch.setLevel(logging.INFO)
|
27 |
+
ch.setFormatter(formatter)
|
28 |
+
logger_init.addHandler(ch)
|
29 |
+
|
30 |
try:
|
31 |
+
logger_init.info("Importing torch...")
|
32 |
import torch
|
33 |
+
logger_init.info("Importing soundfile...")
|
|
|
34 |
import soundfile as sf
|
35 |
+
logger_init.info("Importing numpy...")
|
36 |
import numpy as np
|
37 |
+
logger_init.info("Importing librosa...")
|
38 |
import librosa
|
39 |
+
logger_init.info("Importing speechbrain...")
|
|
|
40 |
import speechbrain.pretrained
|
41 |
+
logger_init.info("Importing demucs...")
|
42 |
import demucs.separate
|
43 |
import demucs.apply
|
44 |
+
logger_init.info("AI and advanced audio libraries imported successfully.")
|
45 |
+
AI_LIBS_AVAILABLE = True
|
46 |
except ImportError as e:
|
47 |
+
logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
|
48 |
+
logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed.")
|
49 |
+
logger_init.error("AI features will be unavailable.")
|
50 |
torch = None
|
51 |
sf = None
|
52 |
np = None
|
53 |
librosa = None
|
54 |
speechbrain = None
|
55 |
demucs = None
|
56 |
+
AI_LIBS_AVAILABLE = False
|
57 |
|
58 |
# --- Configuration & Setup ---
|
59 |
TEMP_DIR = tempfile.gettempdir()
|
60 |
os.makedirs(TEMP_DIR, exist_ok=True)
|
61 |
|
62 |
+
# Configure main app logging (use the root logger setup by FastAPI/Uvicorn)
|
63 |
+
logger = logging.getLogger(__name__) # Will inherit root logger settings
|
64 |
|
65 |
# --- Global Variables for Loaded Models ---
|
|
|
66 |
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
|
67 |
+
SEPARATION_MODEL_KEY = "htdemucs"
|
|
|
68 |
|
69 |
enhancement_models: Dict[str, Any] = {}
|
70 |
separation_models: Dict[str, Any] = {}
|
71 |
|
72 |
+
ENHANCEMENT_SR = 16000
|
73 |
+
DEMUCS_SR = 44100
|
|
|
74 |
|
75 |
# --- Device Selection ---
|
76 |
if torch:
|
77 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
78 |
+
logger_init.info(f"Using device: {DEVICE}")
|
79 |
else:
|
80 |
+
DEVICE = "cpu"
|
81 |
+
logger_init.info("Torch not available, defaulting device to CPU.")
|
82 |
|
|
|
83 |
|
84 |
+
# --- Helper Functions (cleanup_file, save_upload_file - same as before) ---
|
85 |
def cleanup_file(file_path: str):
|
86 |
"""Safely remove a file."""
|
87 |
try:
|
88 |
if file_path and os.path.exists(file_path):
|
89 |
os.remove(file_path)
|
90 |
+
# logger.info(f"Cleaned up temporary file: {file_path}") # Reduce log noise
|
91 |
except Exception as e:
|
92 |
logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
|
93 |
|
94 |
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
|
95 |
"""Saves an uploaded file to a temporary location and returns the path."""
|
96 |
_, file_extension = os.path.splitext(upload_file.filename)
|
|
|
97 |
if not file_extension: file_extension = ".wav"
|
98 |
temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
|
99 |
try:
|
100 |
with open(temp_file_path, "wb") as buffer:
|
101 |
+
while content := await upload_file.read(1024 * 1024): buffer.write(content)
|
|
|
|
|
102 |
logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
|
103 |
return temp_file_path
|
104 |
except Exception as e:
|
105 |
logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
|
106 |
+
cleanup_file(temp_file_path)
|
107 |
raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
|
108 |
finally:
|
109 |
+
await upload_file.close()
|
110 |
|
111 |
+
# --- Audio Loading/Saving Functions (same as before) ---
|
112 |
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
|
113 |
"""Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
|
114 |
+
# ... (Function definition remains the same) ...
|
115 |
try:
|
116 |
audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
|
117 |
+
# logger.debug(...) # Keep debug logs if needed
|
118 |
+
if audio.ndim > 1: # Ensure mono
|
119 |
+
if audio.shape[0] < audio.shape[1] and audio.shape[0] < 10: # Check if first dim is likely channels
|
120 |
+
audio = audio[0, :]
|
121 |
+
elif audio.shape[1] < audio.shape[0] and audio.shape[1] < 10: # Check if second dim is likely channels
|
122 |
+
audio = audio[:, 0]
|
123 |
+
else: # Fallback: Average if dims are ambiguous or many channels
|
124 |
+
logger.warning(f"Ambiguous audio shape {audio.shape}, averaging channels to mono.")
|
125 |
+
audio = np.mean(audio, axis=1 if audio.shape[1] < audio.shape[0] else 0)
|
|
|
|
|
|
|
|
|
126 |
|
127 |
+
audio_tensor = torch.from_numpy(audio).float()
|
128 |
if target_sr and orig_sr != target_sr:
|
129 |
+
if librosa is None: raise RuntimeError("Librosa missing")
|
130 |
+
logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
|
|
|
131 |
audio_np = audio_tensor.numpy()
|
132 |
resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
|
133 |
audio_tensor = torch.from_numpy(resampled_audio_np).float()
|
134 |
current_sr = target_sr
|
|
|
135 |
else:
|
136 |
current_sr = orig_sr
|
|
|
|
|
137 |
return audio_tensor.to(DEVICE), current_sr
|
|
|
138 |
except Exception as e:
|
139 |
logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
|
|
|
140 |
cleanup_file(file_path)
|
141 |
+
raise HTTPException(status_code=415, detail=f"Could not load/process audio file: {os.path.basename(file_path)}. Check format.")
|
142 |
+
|
143 |
|
144 |
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
|
145 |
"""Saves audio data (Tensor or NumPy array) to a temporary file."""
|
146 |
+
# ... (Function definition remains the same) ...
|
147 |
output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
|
148 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
149 |
try:
|
150 |
+
# logger.debug(...) # Keep debug logs if needed
|
|
|
|
|
151 |
if isinstance(audio_data, torch.Tensor):
|
|
|
|
|
152 |
audio_np = audio_data.detach().cpu().numpy()
|
153 |
elif isinstance(audio_data, np.ndarray):
|
154 |
audio_np = audio_data
|
155 |
else:
|
156 |
+
raise TypeError(f"Unsupported audio data type: {type(audio_data)}")
|
157 |
|
158 |
+
if audio_np.dtype != np.float32: audio_np = audio_np.astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
159 |
audio_np = np.clip(audio_np, -1.0, 1.0)
|
160 |
|
|
|
161 |
if output_format.lower() in ['wav', 'flac']:
|
162 |
sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
|
163 |
else:
|
164 |
+
if audio_np.ndim > 1: audio_np_mono = np.mean(audio_np, axis=0 if audio_np.shape[0] < audio_np.shape[1] else 1) # Basic mono conversion
|
165 |
+
else: audio_np_mono = audio_np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
audio_int16 = (audio_np_mono * 32767).astype(np.int16)
|
167 |
+
segment = AudioSegment(audio_int16.tobytes(), frame_rate=sampling_rate, sample_width=audio_int16.dtype.itemsize, channels=1)
|
|
|
|
|
|
|
|
|
|
|
168 |
segment.export(output_path, format=output_format)
|
|
|
169 |
return output_path
|
170 |
except Exception as e:
|
171 |
logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
|
172 |
+
cleanup_file(output_path)
|
173 |
raise HTTPException(status_code=500, detail="Failed to save processed audio.")
|
174 |
|
175 |
+
# --- Pydub Loading/Exporting (for basic edits - same as before) ---
|
|
|
176 |
def load_audio_pydub(file_path: str) -> AudioSegment:
|
177 |
+
# ... (Function definition remains the same) ...
|
178 |
try:
|
179 |
audio = AudioSegment.from_file(file_path)
|
|
|
180 |
return audio
|
181 |
+
except CouldntDecodeError: raise HTTPException(status_code=415, detail=f"Unsupported audio format (pydub): {os.path.basename(file_path)}")
|
182 |
+
except Exception as e: raise HTTPException(status_code=500, detail=f"Error processing audio (pydub): {os.path.basename(file_path)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
def export_audio_pydub(audio: AudioSegment, format: str) -> str:
|
185 |
+
# ... (Function definition remains the same) ...
|
186 |
output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
|
187 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
188 |
try:
|
|
|
189 |
audio.export(output_path, format=format.lower())
|
190 |
return output_path
|
191 |
+
except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to export audio (pydub): {format}")
|
|
|
|
|
|
|
192 |
|
193 |
|
194 |
+
# --- Synchronous AI Inference Functions (same as before) ---
|
|
|
195 |
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
|
196 |
+
# ... (Function definition remains the same) ...
|
197 |
if not model: raise ValueError("Enhancement model not loaded")
|
198 |
try:
|
199 |
+
logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
|
200 |
+
model_device = next(model.parameters()).device
|
201 |
+
if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
|
202 |
+
if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
with torch.no_grad():
|
204 |
enhanced_tensor = model.enhance_batch(audio_tensor, lengths=torch.tensor([audio_tensor.shape[1]]).to(model_device))
|
|
|
|
|
205 |
enhanced_audio = enhanced_tensor.squeeze(0).cpu()
|
206 |
logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
|
207 |
return enhanced_audio
|
208 |
+
except Exception as e: logger.error(f"Sync enhancement error: {e}", exc_info=True); raise
|
|
|
|
|
209 |
|
210 |
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
|
211 |
+
# ... (Function definition remains the same) ...
|
212 |
if not model: raise ValueError("Separation model not loaded")
|
213 |
+
if not demucs: raise RuntimeError("Demucs library missing")
|
214 |
try:
|
215 |
+
logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
|
|
|
|
|
216 |
model_device = next(model.parameters()).device
|
217 |
+
if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
|
218 |
+
if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
|
219 |
+
elif audio_tensor.ndim == 2: audio_tensor = audio_tensor.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
if audio_tensor.shape[1] != model.audio_channels:
|
221 |
+
if audio_tensor.shape[1] == 1: audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
|
222 |
+
else: raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
with torch.no_grad():
|
224 |
+
audio_to_process = audio_tensor.squeeze(0)
|
|
|
|
|
|
|
225 |
out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25)
|
|
|
|
|
|
|
|
|
|
|
226 |
stem_map = {name: out[i] for i, name in enumerate(model.sources)}
|
227 |
+
output_stems = {name: data.mean(dim=0).detach().cpu() for name, data in stem_map.items()}
|
228 |
+
logger.info(f"Separation complete. Stems: {list(output_stems.keys())}")
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
return output_stems
|
230 |
+
except Exception as e: logger.error(f"Sync separation error: {e}", exc_info=True); raise
|
231 |
|
|
|
|
|
|
|
232 |
|
233 |
+
# --- Model Loading Function (Enhanced Logging) ---
|
234 |
def load_hf_models():
|
235 |
"""Loads AI models at startup using correct libraries."""
|
236 |
+
logger_load = logging.getLogger("ModelLoader") # Use specific logger
|
237 |
+
logger_load.setLevel(logging.INFO)
|
238 |
+
if not logger_load.handlers: logger_load.addHandler(ch) # Add handler if not already present
|
239 |
+
|
240 |
global enhancement_models, separation_models
|
241 |
+
if not AI_LIBS_AVAILABLE:
|
242 |
+
logger_load.error("Core AI libraries not available. Cannot load AI models.")
|
243 |
return
|
244 |
|
245 |
+
# --- Load Enhancement Model ---
|
246 |
enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
|
247 |
+
logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
|
248 |
try:
|
249 |
+
# Log device before loading
|
250 |
+
logger_load.info(f"Attempting load on device: {DEVICE}")
|
251 |
enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
|
252 |
source=enhancement_model_hparams,
|
253 |
run_opts={"device": DEVICE}
|
254 |
)
|
255 |
+
# Check model device after loading
|
256 |
+
model_device = next(enhancer.parameters()).device
|
257 |
enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
|
258 |
+
logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
|
259 |
except Exception as e:
|
260 |
+
logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False) # Log only message
|
261 |
+
logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
|
262 |
+
logger_load.warning("Enhancement features will be unavailable.")
|
263 |
+
|
264 |
|
265 |
+
# --- Load Separation Model ---
|
266 |
separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs"
|
267 |
+
logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
|
268 |
try:
|
269 |
+
logger_load.info(f"Attempting load on device: {DEVICE}")
|
270 |
+
# This automatically handles downloading the model checkpoint
|
271 |
separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
|
272 |
+
model_device = next(separator.parameters()).device
|
273 |
separation_models[SEPARATION_MODEL_KEY] = separator
|
274 |
+
logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
|
275 |
+
logger_load.info(f"Separation model sources: {separator.sources}")
|
276 |
except Exception as e:
|
277 |
+
logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
|
278 |
+
logger_load.error(f"Traceback: {traceback.format_exc()}")
|
279 |
+
logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
|
280 |
+
logger_load.warning("Separation features will be unavailable.")
|
281 |
+
|
282 |
+
logger_load.info(f"--- Model loading attempts finished ---")
|
283 |
+
logger_load.info(f"Loaded Enhancement Models: {list(enhancement_models.keys())}")
|
284 |
+
logger_load.info(f"Loaded Separation Models: {list(separation_models.keys())}")
|
285 |
|
286 |
|
287 |
# --- FastAPI App ---
|
288 |
app = FastAPI(
|
289 |
title="AI Audio Editor API",
|
290 |
+
description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
|
291 |
+
version="2.1.1", # Incremented version
|
292 |
)
|
293 |
|
294 |
@app.on_event("startup")
|
295 |
async def startup_event():
|
296 |
+
# Use the init logger for startup messages
|
297 |
+
logger_init.info("--- FastAPI Application Startup ---")
|
298 |
+
if AI_LIBS_AVAILABLE:
|
299 |
+
logger_init.info("AI Libraries appear to be available. Proceeding to load models in background thread...")
|
300 |
+
# Run blocking model load in thread
|
301 |
+
await asyncio.to_thread(load_hf_models)
|
302 |
+
logger_init.info("Background model loading task finished (check ModelLoader logs for details).")
|
303 |
+
else:
|
304 |
+
logger_init.error("AI Libraries failed to import. AI features will be disabled.")
|
305 |
+
logger_init.info("--- Startup complete ---")
|
306 |
|
307 |
# --- API Endpoints ---
|
308 |
|
|
|
311 |
"""Root endpoint providing a welcome message and available features."""
|
312 |
features = ["/trim", "/concat", "/volume", "/convert"]
|
313 |
ai_features = []
|
314 |
+
# Check loaded models dictionary status
|
315 |
if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
|
316 |
+
if separation_models:
|
317 |
+
model = separation_models.get(SEPARATION_MODEL_KEY)
|
318 |
+
sources_str = ', '.join(model.sources) if model else 'N/A'
|
319 |
+
ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY}, sources: {sources_str})")
|
320 |
|
321 |
return {
|
322 |
"message": "Welcome to the AI Audio Editor API.",
|
323 |
+
"status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
|
324 |
+
"loaded_enhancement_models": list(enhancement_models.keys()),
|
325 |
+
"loaded_separation_models": list(separation_models.keys()),
|
326 |
"basic_features": features,
|
327 |
"ai_features": ai_features if ai_features else "None available (check startup logs)",
|
328 |
+
"notes": "Requires FFmpeg. AI features require models to load successfully at startup."
|
329 |
}
|
330 |
|
|
|
331 |
|
332 |
+
# --- Basic Editing Endpoints ---
|
333 |
+
# (Add /trim, /concat, /volume, /convert endpoints here - unchanged)
|
334 |
@app.post("/trim", tags=["Basic Editing"])
|
335 |
+
async def trim_audio( background_tasks: BackgroundTasks, file: UploadFile = File(...), start_ms: int = Form(...), end_ms: int = Form(...)):
|
336 |
+
if start_ms < 0 or end_ms <= start_ms: raise HTTPException(422, "Invalid start/end times.")
|
337 |
+
input_path = await save_upload_file(file, "trim_in_")
|
338 |
+
background_tasks.add_task(cleanup_file, input_path); output_path = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
try:
|
340 |
audio = load_audio_pydub(input_path)
|
341 |
trimmed_audio = audio[start_ms:end_ms]
|
342 |
+
fmt = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
|
343 |
+
output_path = export_audio_pydub(trimmed_audio, fmt)
|
344 |
+
background_tasks.add_task(cleanup_file, output_path)
|
345 |
+
fname = f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{fmt}"
|
346 |
+
return FileResponse(output_path, media_type=f"audio/{fmt}", filename=fname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
except Exception as e:
|
|
|
348 |
if output_path: cleanup_file(output_path)
|
349 |
+
if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Trim error: {e}")
|
|
|
350 |
|
351 |
@app.post("/concat", tags=["Basic Editing"])
|
352 |
+
async def concatenate_audio( background_tasks: BackgroundTasks, files: List[UploadFile] = File(...), output_format: str = Form("mp3")):
|
353 |
+
if len(files) < 2: raise HTTPException(422, "Need at least two files.")
|
354 |
+
input_paths, loaded_audios, output_path = [], [], None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
try:
|
356 |
+
combined = None
|
357 |
for file in files:
|
358 |
+
ip = await save_upload_file(file, "concat_in_")
|
359 |
+
input_paths.append(ip); background_tasks.add_task(cleanup_file, ip)
|
360 |
+
audio = load_audio_pydub(ip)
|
361 |
+
combined = (combined + audio) if combined else audio
|
362 |
+
if not combined: raise ValueError("No audio loaded.")
|
363 |
+
output_path = export_audio_pydub(combined, output_format)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
background_tasks.add_task(cleanup_file, output_path)
|
365 |
+
fname = f"concat_{os.path.splitext(files[0].filename)[0]}_{len(files)-1}_others.{output_format}"
|
366 |
+
return FileResponse(output_path, media_type=f"audio/{output_format}", filename=fname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
except Exception as e:
|
368 |
+
for p in input_paths: cleanup_file(p);
|
|
|
|
|
369 |
if output_path: cleanup_file(output_path)
|
370 |
+
if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Concat error: {e}")
|
|
|
371 |
|
372 |
@app.post("/volume", tags=["Basic Editing"])
|
373 |
+
async def change_volume( background_tasks: BackgroundTasks, file: UploadFile = File(...), change_db: float = Form(...)):
|
374 |
+
input_path = await save_upload_file(file, "volume_in_")
|
375 |
+
background_tasks.add_task(cleanup_file, input_path); output_path = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
try:
|
377 |
audio = load_audio_pydub(input_path)
|
378 |
+
adjusted = audio + change_db
|
379 |
+
fmt = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
|
380 |
+
output_path = export_audio_pydub(adjusted, fmt)
|
|
|
|
|
|
|
|
|
381 |
background_tasks.add_task(cleanup_file, output_path)
|
382 |
+
fname = f"volume_{change_db}dB_{os.path.splitext(file.filename)[0]}.{fmt}"
|
383 |
+
return FileResponse(output_path, media_type=f"audio/{fmt}", filename=fname)
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
except Exception as e:
|
|
|
385 |
if output_path: cleanup_file(output_path)
|
386 |
+
if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Volume error: {e}")
|
|
|
387 |
|
388 |
@app.post("/convert", tags=["Basic Editing"])
|
389 |
+
async def convert_format( background_tasks: BackgroundTasks, file: UploadFile = File(...), output_format: str = Form(...)):
|
390 |
+
allowed = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus'}
|
391 |
+
if output_format.lower() not in allowed: raise HTTPException(422, f"Invalid format. Allowed: {allowed}")
|
392 |
+
input_path = await save_upload_file(file, "convert_in_")
|
393 |
+
background_tasks.add_task(cleanup_file, input_path); output_path = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
try:
|
|
|
395 |
audio = load_audio_pydub(input_path)
|
|
|
396 |
output_path = export_audio_pydub(audio, output_format.lower())
|
397 |
background_tasks.add_task(cleanup_file, output_path)
|
398 |
+
fname = f"{os.path.splitext(file.filename)[0]}_converted.{output_format.lower()}"
|
399 |
+
return FileResponse(output_path, media_type=f"audio/{output_format.lower()}", filename=fname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
except Exception as e:
|
|
|
401 |
if output_path: cleanup_file(output_path)
|
402 |
+
if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Convert error: {e}")
|
|
|
403 |
|
404 |
|
405 |
+
# --- AI Endpoints (Unchanged Functionality, relies on successful loading) ---
|
406 |
|
407 |
@app.post("/enhance", tags=["AI Editing"])
|
408 |
async def enhance_speech(
|
|
|
412 |
output_format: str = Form("wav", description="Output format (wav, flac recommended).")
|
413 |
):
|
414 |
"""Enhances speech audio using a pre-loaded SpeechBrain model."""
|
415 |
+
if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
|
|
|
416 |
if model_key not in enhancement_models:
|
417 |
logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
|
418 |
+
raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server startup logs.")
|
419 |
|
420 |
loaded_model = enhancement_models[model_key]
|
|
|
421 |
logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
|
422 |
input_path = await save_upload_file(file, prefix="enhance_in_")
|
423 |
background_tasks.add_task(cleanup_file, input_path)
|
424 |
output_path = None
|
|
|
425 |
try:
|
|
|
426 |
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
|
|
|
427 |
logger.info("Submitting enhancement task to background thread...")
|
428 |
enhanced_audio_tensor = await asyncio.to_thread(
|
429 |
_run_enhancement_sync, loaded_model, audio_tensor, current_sr
|
430 |
)
|
431 |
logger.info("Enhancement task completed.")
|
|
|
|
|
432 |
output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
|
433 |
background_tasks.add_task(cleanup_file, output_path)
|
|
|
434 |
output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
|
435 |
+
return FileResponse(path=output_path, media_type=f"audio/{output_format}", filename=output_filename)
|
|
|
|
|
|
|
|
|
436 |
except Exception as e:
|
437 |
logger.error(f"Error during enhancement operation: {e}", exc_info=True)
|
438 |
+
if output_path: cleanup_file(output_path)
|
439 |
+
if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Enhancement error: {e}")
|
|
|
440 |
|
441 |
|
442 |
@app.post("/separate", tags=["AI Editing"])
|
|
|
448 |
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
|
449 |
):
|
450 |
"""Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
|
451 |
+
if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
|
|
|
452 |
if model_key not in separation_models:
|
453 |
logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
|
454 |
+
raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server startup logs.")
|
455 |
|
456 |
loaded_model = separation_models[model_key]
|
457 |
+
valid_stems = set(loaded_model.sources)
|
458 |
requested_stems = set(s.lower() for s in stems)
|
459 |
if not requested_stems.issubset(valid_stems):
|
460 |
+
raise HTTPException(422, f"Invalid stem(s). Model '{model_key}' provides: {valid_stems}")
|
461 |
|
462 |
logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
|
463 |
input_path = await save_upload_file(file, prefix="separate_in_")
|
464 |
background_tasks.add_task(cleanup_file, input_path)
|
465 |
stem_output_paths: Dict[str, str] = {}
|
466 |
+
zip_buffer = io.BytesIO(); zipf = None # Define zipf here
|
467 |
|
468 |
try:
|
|
|
469 |
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
|
|
|
470 |
logger.info("Submitting separation task to background thread...")
|
471 |
all_separated_stems_tensors = await asyncio.to_thread(
|
472 |
_run_separation_sync, loaded_model, audio_tensor, current_sr
|
473 |
)
|
474 |
logger.info("Separation task completed.")
|
475 |
|
476 |
+
zipf = zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED)
|
477 |
+
for stem_name in requested_stems:
|
478 |
+
if stem_name in all_separated_stems_tensors:
|
479 |
+
stem_tensor = all_separated_stems_tensors[stem_name]
|
480 |
+
stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
|
481 |
+
stem_output_paths[stem_name] = stem_path
|
482 |
+
background_tasks.add_task(cleanup_file, stem_path) # Cleanup after response sent
|
483 |
+
archive_name = f"{stem_name}.{output_format}"
|
484 |
+
zipf.write(stem_path, arcname=archive_name)
|
485 |
+
logger.info(f"Added '{archive_name}' to ZIP.")
|
486 |
+
|
487 |
+
zipf.close() # Close zip file BEFORE seeking/reading
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
zip_buffer.seek(0)
|
489 |
|
490 |
zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
|
|
|
495 |
)
|
496 |
except Exception as e:
|
497 |
logger.error(f"Error during separation operation: {e}", exc_info=True)
|
498 |
+
# Ensure buffer/zipfile are closed and temp files cleaned up on error
|
499 |
+
if zipf: zipf.close() # Ensure zipfile is closed
|
500 |
+
if zip_buffer: zip_buffer.close()
|
501 |
for path in stem_output_paths.values(): cleanup_file(path)
|
|
|
502 |
if isinstance(e, HTTPException): raise e
|
503 |
+
else: raise HTTPException(500, f"Separation error: {e}")
|
504 |
+
|
505 |
+
# --- (How to Run instructions remain the same) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|