Athspi committed on
Commit
d00fd38
·
verified ·
1 Parent(s): 2c84da8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -751
app.py CHANGED
@@ -1,889 +1,441 @@
1
- # app.py
2
  import os
3
  import uuid
4
  import tempfile
5
  import logging
6
- import asyncio
7
- from typing import List, Optional, Dict, Any
8
- import traceback # For detailed error logging
9
 
10
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
11
- from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
12
- import io
13
- import zipfile
14
-
15
- # --- Basic Editing Imports ---
16
  from pydub import AudioSegment
17
  from pydub.exceptions import CouldntDecodeError
18
 
19
- # --- AI & Advanced Audio Imports ---
20
- # Add extra logging around imports
21
- logger_init = logging.getLogger("AppInit")
22
- logger_init.setLevel(logging.INFO)
23
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
24
- # Create console handler and set level to info
25
- ch = logging.StreamHandler()
26
- ch.setLevel(logging.INFO)
27
- ch.setFormatter(formatter)
28
- # Avoid adding handler multiple times if script reloads
29
- if not logger_init.handlers:
30
- logger_init.addHandler(ch)
31
-
32
- AI_LIBS_AVAILABLE = False
33
  try:
34
- logger_init.info("Importing torch...")
35
- import torch
36
- logger_init.info("Importing soundfile...")
37
- import soundfile as sf
38
- logger_init.info("Importing numpy...")
39
- import numpy as np
40
- logger_init.info("Importing librosa...")
41
- import librosa
42
- logger_init.info("Importing speechbrain...")
43
- import speechbrain.pretrained
44
- logger_init.info("Importing demucs...")
45
- import demucs.separate
46
- import demucs.apply
47
- logger_init.info("AI and advanced audio libraries imported successfully.")
48
- AI_LIBS_AVAILABLE = True
49
- except ImportError as e:
50
- logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
51
- logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
52
- logger_init.error("AI features will be unavailable.")
53
- # Define placeholders so the rest of the code doesn't break completely on import error
54
- torch = None
55
- sf = None
56
- np = None
57
- librosa = None
58
- speechbrain = None
59
- demucs = None
60
 
61
  # --- Configuration & Setup ---
62
  TEMP_DIR = tempfile.gettempdir()
63
- # Attempt to create temp dir if it doesn't exist (useful in some environments)
64
- try:
65
- os.makedirs(TEMP_DIR, exist_ok=True)
66
- except OSError as e:
67
- logger_init.error(f"Could not create temporary directory {TEMP_DIR}: {e}")
68
- # Fallback or raise an error depending on desired behavior
69
- TEMP_DIR = "." # Use current directory as fallback (less ideal)
70
- logger_init.warning(f"Using current directory '{TEMP_DIR}' for temporary files.")
71
 
72
-
73
- # Configure main app logging (use the root logger setup by FastAPI/Uvicorn)
74
- # This logger will be used by endpoint handlers
75
  logger = logging.getLogger(__name__)
76
 
77
- # --- Global Variables for Loaded Models ---
78
- ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
79
- # Choose a default Demucs model (htdemucs is good quality)
80
- SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
81
-
82
- enhancement_models: Dict[str, Any] = {}
83
- separation_models: Dict[str, Any] = {}
84
-
85
- # Target sampling rates (confirm from model specifics if necessary)
86
- ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
87
- DEMUCS_SR = 44100 # Demucs default is 44.1kHz
88
-
89
- # --- Device Selection ---
90
- if AI_LIBS_AVAILABLE and torch:
91
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
92
- logger_init.info(f"Selected device for AI models: {DEVICE}")
93
- else:
94
- DEVICE = "cpu" # Fallback if torch failed import
95
- logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")
96
 
 
 
 
 
 
 
97
 
98
- # --- Helper Functions ---
99
 
100
- def cleanup_file(file_path: str):
101
- """Safely remove a file."""
102
  try:
103
- if file_path and isinstance(file_path, str) and os.path.exists(file_path):
104
- os.remove(file_path)
105
- # logger.info(f"Cleaned up temporary file: {file_path}") # Reduce log noise
 
 
 
 
 
 
 
 
 
 
106
  except Exception as e:
107
- # Log error but don't crash the cleanup process for other files
108
- logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
109
 
110
- async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
111
  """Saves an uploaded file to a temporary location and returns the path."""
112
- if not upload_file or not upload_file.filename:
113
- raise HTTPException(status_code=400, detail="Invalid file upload object.")
114
-
115
- _, file_extension = os.path.splitext(upload_file.filename)
116
- # Default to .wav if no extension, as it's widely compatible for loading
117
- if not file_extension: file_extension = ".wav"
118
- temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
119
 
120
  try:
121
- logger.debug(f"Attempting to save uploaded file to: {temp_file_path}")
122
  with open(temp_file_path, "wb") as buffer:
123
- # Read chunk by chunk for large files
124
- while content := await upload_file.read(1024 * 1024): # 1MB chunks
125
- buffer.write(content)
126
- logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {temp_file_path}")
127
  return temp_file_path
128
  except Exception as e:
129
- logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {temp_file_path}: {e}", exc_info=True)
130
- cleanup_file(temp_file_path) # Attempt cleanup if saving failed
131
  raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
132
  finally:
133
- # Ensure file is closed even if saving fails mid-way
134
- try:
135
- await upload_file.close()
136
- except Exception:
137
- pass # Ignore errors during close if already failed
138
-
139
-
140
- # --- Audio Loading/Saving Functions ---
141
-
142
- def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
143
- """Loads audio using soundfile, converts to mono float32 Torch tensor, optionally resamples."""
144
- if not AI_LIBS_AVAILABLE:
145
- raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
146
- if not os.path.exists(file_path):
147
- raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")
148
-
149
- try:
150
- audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
151
- logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")
152
-
153
- # Ensure mono
154
- if audio.ndim > 1:
155
- # Check which dimension is smaller (likely channels)
156
- channel_dim = np.argmin(audio.shape)
157
- if audio.shape[channel_dim] > 1 and audio.shape[channel_dim] < 10: # Heuristic: <10 channels
158
- logger.info(f"Detected {audio.shape[channel_dim]} channels. Converting to mono by averaging axis {channel_dim}.")
159
- audio = np.mean(audio, axis=channel_dim)
160
- else: # Fallback or if shape is ambiguous (e.g., very short stereo)
161
- logger.warning(f"Audio has shape {audio.shape}. Taking first channel/element assuming mono or channel-first.")
162
- audio = audio[0] if channel_dim == 0 else audio[:, 0] # Select first index of the likely channel dimension
163
-
164
- logger.debug(f"Shape after mono conversion: {audio.shape}")
165
-
166
-
167
- # Ensure it's now 1D
168
- audio = audio.flatten()
169
-
170
- # Convert numpy array to torch tensor
171
- audio_tensor = torch.from_numpy(audio).float()
172
-
173
- # Resample if necessary using librosa
174
- current_sr = orig_sr
175
- if target_sr and orig_sr != target_sr:
176
- if librosa is None: raise RuntimeError("Librosa missing for resampling")
177
- logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
178
- # Librosa works on numpy
179
- audio_np = audio_tensor.numpy()
180
- resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best') # Specify resampling type
181
- audio_tensor = torch.from_numpy(resampled_audio_np).float()
182
- current_sr = target_sr
183
- logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
184
-
185
- # Ensure tensor is on the correct device
186
- return audio_tensor.to(DEVICE), current_sr
187
-
188
- except sf.SoundFileError as sf_err:
189
- logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
190
- cleanup_file(file_path)
191
- raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}")
192
- except Exception as e:
193
- logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
194
- cleanup_file(file_path)
195
- raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.")
196
-
197
-
198
- def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
199
- """Saves audio data (Tensor or NumPy array) to a temporary file."""
200
- if not AI_LIBS_AVAILABLE:
201
- raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
202
-
203
- output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
204
- output_path = os.path.join(TEMP_DIR, output_filename)
205
- try:
206
- logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
207
-
208
- # Convert tensor to numpy array if needed
209
- if isinstance(audio_data, torch.Tensor):
210
- logger.debug("Converting output tensor to NumPy array.")
211
- # Ensure tensor is on CPU before converting to numpy
212
- audio_np = audio_data.detach().cpu().numpy()
213
- elif isinstance(audio_data, np.ndarray):
214
- audio_np = audio_data
215
- else:
216
- raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
217
-
218
- # Ensure data is float32
219
- if audio_np.dtype != np.float32:
220
- logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
221
- audio_np = audio_np.astype(np.float32)
222
-
223
- # Clip values to avoid potential issues with formats expecting [-1, 1]
224
- audio_np = np.clip(audio_np, -1.0, 1.0)
225
-
226
- # Ensure audio is 1D (mono) before saving with soundfile or pydub conversion
227
- if audio_np.ndim > 1:
228
- logger.warning(f"Output audio data has {audio_np.ndim} dimensions, attempting to flatten or take first dimension.")
229
- # Try averaging channels if shape suggests stereo/multi-channel
230
- channel_dim = np.argmin(audio_np.shape)
231
- if audio_np.shape[channel_dim] > 1 and audio_np.shape[channel_dim] < 10:
232
- audio_np = np.mean(audio_np, axis=channel_dim)
233
- else: # Otherwise just flatten
234
- audio_np = audio_np.flatten()
235
-
236
-
237
- # Use soundfile (preferred for wav/flac)
238
- if output_format.lower() in ['wav', 'flac']:
239
- sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
240
- else:
241
- # For lossy formats, use pydub
242
- logger.debug(f"Using pydub to export to lossy format: {output_format}")
243
- # Scale float32 [-1, 1] to int16 for pydub
244
- audio_int16 = (audio_np * 32767).astype(np.int16)
245
- segment = AudioSegment(
246
- audio_int16.tobytes(),
247
- frame_rate=sampling_rate,
248
- sample_width=audio_int16.dtype.itemsize,
249
- channels=1 # Assuming mono after processing above
250
- )
251
- # Pydub might need explicit ffmpeg path in some envs
252
- # AudioSegment.converter = "/path/to/ffmpeg" # Uncomment and set path if needed
253
- segment.export(output_path, format=output_format)
254
-
255
- logger.info(f"Successfully saved AI audio to {output_path}")
256
- return output_path
257
- except Exception as e:
258
- logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
259
- cleanup_file(output_path) # Attempt cleanup on saving failure
260
- raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.")
261
-
262
 
263
- # --- Pydub Loading/Exporting (for basic edits) ---
264
- def load_audio_pydub(file_path: str) -> AudioSegment:
265
  """Loads an audio file using pydub."""
266
- if not os.path.exists(file_path):
267
- raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")
268
  try:
269
- logger.debug(f"Loading audio with pydub: {file_path}")
270
- # Explicitly provide format if possible, helps pydub sometimes
271
- file_ext = os.path.splitext(file_path)[1][1:].lower()
272
- if file_ext:
273
- audio = AudioSegment.from_file(file_path, format=file_ext)
274
- else:
275
- audio = AudioSegment.from_file(file_path) # Let pydub detect
276
- logger.info(f"Loaded audio using pydub from: {file_path}")
277
  return audio
278
- except CouldntDecodeError as e:
279
- logger.warning(f"Pydub CouldntDecodeError for {file_path}: {e}")
280
- cleanup_file(file_path)
281
- raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
 
 
282
  except Exception as e:
283
- logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
284
- cleanup_file(file_path)
285
- raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")
286
-
287
- def export_audio_pydub(audio: AudioSegment, format: str) -> str:
288
- """Exports a Pydub AudioSegment to a temporary file and returns the path."""
289
- output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
 
290
  output_path = os.path.join(TEMP_DIR, output_filename)
291
  try:
292
- logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
293
- audio.export(output_path, format=format.lower())
294
- return output_path
295
- except Exception as e:
296
- logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
297
- cleanup_file(output_path) # Cleanup if export failed
298
- raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
299
-
300
-
301
- # --- Synchronous AI Inference Functions ---
302
-
303
- def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
304
- """Synchronous wrapper for SpeechBrain enhancement model inference."""
305
- if not AI_LIBS_AVAILABLE or not model: raise ValueError("Enhancement model/libs not available")
306
- try:
307
- logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
308
- model_device = next(model.parameters()).device # Check model's current device
309
- if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
310
- # Add batch dimension if model expects it (most do)
311
- if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0)
312
-
313
- with torch.no_grad():
314
- # Check if model expects lengths parameter
315
- enhance_method = getattr(model, "enhance_batch", getattr(model, "forward", None))
316
- if "lengths" in enhance_method.__code__.co_varnames:
317
- enhanced_tensor = enhance_method(audio_tensor, lengths=torch.tensor([audio_tensor.shape[-1]]).to(model_device))
318
- else:
319
- enhanced_tensor = enhance_method(audio_tensor)
320
-
321
-
322
- # Remove batch dimension from output before returning, move back to CPU
323
- enhanced_audio = enhanced_tensor.squeeze(0).cpu()
324
- logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
325
- return enhanced_audio
326
- except Exception as e:
327
- logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
328
- raise # Re-raise to be caught by the async wrapper
329
-
330
- def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
331
- """Synchronous wrapper for Demucs source separation model inference."""
332
- if not AI_LIBS_AVAILABLE or not model: raise ValueError("Separation model/libs not available")
333
- if not demucs: raise RuntimeError("Demucs library missing")
334
- try:
335
- logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
336
- model_device = next(model.parameters()).device
337
- if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
338
-
339
- # Demucs expects audio as (batch, channels, samples)
340
- if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
341
- elif audio_tensor.ndim == 2: audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
342
-
343
- # Repeat channel if model expects stereo but input is mono
344
- if audio_tensor.shape[1] != model.audio_channels:
345
- if audio_tensor.shape[1] == 1:
346
- logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
347
- audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
348
- else:
349
- raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
350
-
351
- logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
352
-
353
- with torch.no_grad():
354
- # Use demucs.apply.apply_model for handling chunking etc.
355
- # Requires input shape (channels, samples) - process first batch item
356
- audio_to_process = audio_tensor.squeeze(0)
357
- # Note: shifts=1, split=True are common defaults for quality
358
- out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25, progress=False) # Disable progress bar in logs
359
- # Output shape (stems, channels, samples)
360
-
361
- logger.debug(f"Raw separated sources tensor shape: {out.shape}")
362
-
363
- # Map stems based on the model's sources list
364
- stem_map = {name: out[i] for i, name in enumerate(model.sources)}
365
-
366
- # Convert back to mono for simplicity (average channels) and move to CPU
367
- output_stems = {}
368
- for name, data in stem_map.items():
369
- # Average channels, detach, move to CPU
370
- output_stems[name] = data.mean(dim=0).detach().cpu()
371
-
372
- logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
373
- return output_stems
374
-
375
- except Exception as e:
376
- logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
377
- raise
378
-
379
-
380
- # --- Model Loading Function (Enhanced Logging) ---
381
- def load_hf_models():
382
- """Loads AI models at startup using correct libraries."""
383
- logger_load = logging.getLogger("ModelLoader") # Use specific logger
384
- logger_load.setLevel(logging.INFO)
385
- # Ensure handler is attached if logger is newly created
386
- if not logger_load.handlers and ch: logger_load.addHandler(ch)
387
 
388
- global enhancement_models, separation_models
389
- if not AI_LIBS_AVAILABLE:
390
- logger_load.error("Core AI libraries not available. Cannot load AI models.")
391
- return
392
-
393
- load_success_flags = {"enhancement": False, "separation": False}
394
-
395
- # --- Load Enhancement Model ---
396
- enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
397
- logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
398
- try:
399
- logger_load.info(f"Attempting load on device: {DEVICE}")
400
- # Consider adding savedir if cache issues arise in HF Spaces
401
- # savedir_sb = os.path.join(TEMP_DIR, "speechbrain_models")
402
- # os.makedirs(savedir_sb, exist_ok=True)
403
- enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
404
- source=enhancement_model_hparams,
405
- # savedir=savedir_sb,
406
- run_opts={"device": DEVICE}
407
- )
408
- model_device = next(enhancer.parameters()).device
409
- enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
410
- logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
411
- load_success_flags["enhancement"] = True
412
- except Exception as e:
413
- logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
414
- logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
415
- logger_load.warning("Enhancement features will be unavailable.")
416
-
417
-
418
- # --- Load Separation Model ---
419
- separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs"
420
- logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
421
- try:
422
- logger_load.info(f"Attempting load on device: {DEVICE}")
423
- # This automatically handles downloading the model checkpoint via demucs package
424
- separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
425
- model_device = next(separator.parameters()).device
426
- separation_models[SEPARATION_MODEL_KEY] = separator
427
- logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
428
- logger_load.info(f"Separation model available sources: {separator.sources}")
429
- load_success_flags["separation"] = True
430
  except Exception as e:
431
- logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
432
- logger_load.error(f"Traceback: {traceback.format_exc()}")
433
- logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
434
- logger_load.warning("Separation features will be unavailable.")
435
 
436
- logger_load.info(f"--- Model loading attempts finished ---")
437
- logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
438
- logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")
439
-
440
-
441
- # --- FastAPI App ---
442
- app = FastAPI(
443
- title="AI Audio Editor API",
444
- description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
445
- version="2.1.2", # Incremented version
446
- )
447
-
448
- @app.on_event("startup")
449
- async def startup_event():
450
- # Use the init logger for startup messages
451
- logger_init.info("--- FastAPI Application Startup ---")
452
- if AI_LIBS_AVAILABLE:
453
- logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
454
- # Run blocking model load in thread
455
- await asyncio.to_thread(load_hf_models)
456
- logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
457
- else:
458
- logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
459
- logger_init.info("--- Startup sequence complete ---")
460
 
461
  # --- API Endpoints ---
462
 
463
  @app.get("/", tags=["General"])
464
  def read_root():
465
- """Root endpoint providing a welcome message and status of loaded models."""
466
- features = ["/trim", "/concat", "/volume", "/convert"]
467
- ai_features_status = {}
468
-
469
- if AI_LIBS_AVAILABLE:
470
- if enhancement_models:
471
- ai_features_status[ENHANCEMENT_MODEL_KEY] = "Loaded"
472
- else:
473
- ai_features_status[ENHANCEMENT_MODEL_KEY] = "Failed to load (check startup logs)"
474
-
475
- if separation_models:
476
- model = separation_models.get(SEPARATION_MODEL_KEY)
477
- sources_str = ', '.join(model.sources) if model else 'N/A'
478
- ai_features_status[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {sources_str})"
479
- else:
480
- ai_features_status[SEPARATION_MODEL_KEY] = "Failed to load (check startup logs)"
481
  else:
482
- ai_features_status["AI Status"] = "Libraries Failed Import"
483
-
484
-
485
  return {
486
- "message": "Welcome to the AI Audio Editor API.",
487
- "status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
488
- "ai_models_status": ai_features_status,
489
- "basic_endpoints": features,
490
- "notes": "Requires FFmpeg. AI features require successful model loading at startup."
491
  }
492
 
 
 
 
493
 
494
- # --- Basic Editing Endpoints ---
495
-
496
- @app.post("/trim", tags=["Basic Editing"])
497
  async def trim_audio(
498
  background_tasks: BackgroundTasks,
499
  file: UploadFile = File(..., description="Audio file to trim."),
500
- start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
501
- end_ms: int = Form(..., gt=0, description="End time in milliseconds.") # Ensure end > 0
502
  ):
503
- """Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub."""
504
- if end_ms <= start_ms:
505
- raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")
506
 
507
  logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
508
- input_path = await save_upload_file(file, prefix="trim_in_")
509
- # Schedule cleanup immediately after saving, even if loading fails later
510
- background_tasks.add_task(cleanup_file, input_path)
511
- output_path = None # Define before try block
512
 
 
513
  try:
514
- audio = load_audio_pydub(input_path) # Can raise HTTPException
515
  trimmed_audio = audio[start_ms:end_ms]
516
  logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
517
 
518
- # Determine original format for export
519
- original_format = os.path.splitext(file.filename)[1][1:].lower()
520
- # Use mp3 as default only if no extension or if it's 'tmp' etc.
521
- if not original_format or len(original_format) > 5: # Basic check for valid extension length
522
- original_format = "mp3"
523
- logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")
524
-
525
- output_path = export_audio_pydub(trimmed_audio, original_format) # Can raise HTTPException
526
- background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
527
 
528
- # Create a more informative filename
529
- output_filename=f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"
530
 
531
  return FileResponse(
532
  path=output_path,
533
- media_type=f"audio/{original_format}", # Best guess for media type
534
- filename=output_filename
535
  )
536
- except HTTPException as http_exc:
537
- # If load/export raised HTTPException, re-raise it
538
- # Cleanup might have already been scheduled, background tasks handle errors
539
- logger.error(f"HTTP Exception during trim: {http_exc.detail}")
540
- if output_path: cleanup_file(output_path) # Try immediate cleanup if output exists
541
- raise http_exc
542
  except Exception as e:
543
- # Catch other unexpected errors during trimming logic
544
- logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
545
- if output_path: cleanup_file(output_path)
546
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}")
 
 
547
 
548
 
549
- @app.post("/concat", tags=["Basic Editing"])
550
  async def concatenate_audio(
551
  background_tasks: BackgroundTasks,
552
  files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
553
  output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
554
  ):
555
- """Concatenates two or more audio files sequentially using Pydub."""
556
  if len(files) < 2:
557
  raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
558
 
559
  logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
560
- input_paths = [] # Keep track of all saved input file paths
561
- output_path = None # Define before try block
 
562
 
563
  try:
564
- combined_audio: Optional[AudioSegment] = None
565
- for i, file in enumerate(files):
566
- if not file or not file.filename:
567
- logger.warning(f"Skipping invalid file upload at index {i}.")
568
- continue # Skip potentially empty file entries
569
-
570
- input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
571
- input_paths.append(input_path)
572
- # Schedule cleanup for this specific input file immediately
573
- background_tasks.add_task(cleanup_file, input_path)
574
-
575
- try:
576
- audio = load_audio_pydub(input_path)
577
- if combined_audio is None:
578
- combined_audio = audio
579
- logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
580
- else:
581
- logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
582
- combined_audio += audio
583
- except HTTPException as load_exc:
584
- # Log error but continue trying to load other files if possible
585
- logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
586
- except Exception as load_exc:
587
- logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)
588
-
589
-
590
- if combined_audio is None:
591
- raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")
592
-
593
- logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")
594
-
595
- output_path = export_audio_pydub(combined_audio, output_format) # Can raise HTTPException
596
- background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
597
-
598
- # Determine a reasonable output filename
599
- first_valid_filename = files[0].filename if files and files[0] else "audio"
600
- first_filename_base = os.path.splitext(first_valid_filename)[0]
601
- output_filename = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
602
 
603
  return FileResponse(
604
  path=output_path,
605
  media_type=f"audio/{output_format}",
606
- filename=output_filename
607
  )
608
- except HTTPException as http_exc:
609
- # If load/export raised HTTPException, re-raise it
610
- logger.error(f"HTTP Exception during concat: {http_exc.detail}")
611
- # Cleanup for output path, inputs are handled by background tasks
612
- if output_path: cleanup_file(output_path)
613
- raise http_exc
614
  except Exception as e:
615
- # Catch other unexpected errors during combining logic
616
- logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
617
- if output_path: cleanup_file(output_path)
618
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}")
 
619
 
620
 
621
- @app.post("/volume", tags=["Basic Editing"])
622
  async def change_volume(
623
  background_tasks: BackgroundTasks,
624
  file: UploadFile = File(..., description="Audio file to adjust volume for."),
625
- change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
626
  ):
627
- """Adjusts the volume of an audio file by a specified decibel amount using Pydub."""
628
  logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
629
- input_path = await save_upload_file(file, prefix="volume_in_")
630
- background_tasks.add_task(cleanup_file, input_path)
 
631
  output_path = None
 
632
  try:
633
- audio = load_audio_pydub(input_path)
634
- # Check for potential silence before applying gain
635
- if audio.dBFS == -float('inf'):
636
- logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
637
  adjusted_audio = audio + change_db
638
  logger.info(f"Volume adjusted by {change_db}dB.")
639
 
640
- original_format = os.path.splitext(file.filename)[1][1:].lower()
641
- if not original_format or len(original_format) > 5: original_format = "mp3"
642
-
643
- output_path = export_audio_pydub(adjusted_audio, original_format)
644
- background_tasks.add_task(cleanup_file, output_path)
645
 
646
- # Create filename
647
- sign = "+" if change_db >= 0 else ""
648
- output_filename=f"volume_{sign}{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"
649
 
650
  return FileResponse(
651
  path=output_path,
652
  media_type=f"audio/{original_format}",
653
- filename=output_filename
654
  )
655
- except HTTPException as http_exc:
656
- logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
657
- if output_path: cleanup_file(output_path)
658
- raise http_exc
659
  except Exception as e:
660
- logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
661
- if output_path: cleanup_file(output_path)
662
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}")
 
663
 
664
 
665
- @app.post("/convert", tags=["Basic Editing"])
666
  async def convert_format(
667
  background_tasks: BackgroundTasks,
668
  file: UploadFile = File(..., description="Audio file to convert."),
669
  output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
670
  ):
671
- """Converts an audio file to a different format using Pydub."""
672
- # Define allowed formats explicitly
673
- allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'} # Expanded list
674
- output_format_lower = output_format.lower()
675
- if output_format_lower not in allowed_formats:
676
- raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(list(allowed_formats)))}")
677
-
678
- logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
679
- input_path = await save_upload_file(file, prefix="convert_in_")
680
- background_tasks.add_task(cleanup_file, input_path)
681
  output_path = None
682
- try:
683
- # Load using pydub, which handles many input formats
684
- audio = load_audio_pydub(input_path)
685
- logger.info(f"Successfully loaded '{file.filename}' for conversion.")
686
-
687
- # Export using pydub
688
- output_path = export_audio_pydub(audio, output_format_lower)
689
- background_tasks.add_task(cleanup_file, output_path)
690
- logger.info(f"Successfully exported to {output_format_lower}")
691
 
692
- # Construct new filename
 
693
  filename_base = os.path.splitext(file.filename)[0]
694
- output_filename = f"{filename_base}_converted.{output_format_lower}"
695
 
696
- # Determine media type (MIME type) - might need refinement for less common types
697
- media_type_map = {
698
- 'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
699
- 'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4', # m4a often uses mp4 container
700
- 'opus': 'audio/opus', 'wma':'audio/x-ms-wma', 'aiff':'audio/aiff'
701
- }
702
- media_type = media_type_map.get(output_format_lower, 'application/octet-stream') # Default binary if unknown
703
 
704
  return FileResponse(
705
  path=output_path,
706
- media_type=media_type,
707
- filename=output_filename
708
  )
709
- except HTTPException as http_exc:
710
- logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
711
- if output_path: cleanup_file(output_path)
712
- raise http_exc
713
  except Exception as e:
714
- logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
715
- if output_path: cleanup_file(output_path)
716
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}")
 
717
 
718
 
719
- # --- AI Endpoints ---
720
 
721
- @app.post("/enhance", tags=["AI Editing"])
722
- async def enhance_speech(
723
  background_tasks: BackgroundTasks,
724
- file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
725
- # Keep model_key optional for now, assumes default if only one loaded
726
- model_key: Optional[str] = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use (defaults to primary)."),
727
- output_format: str = Form("wav", description="Output format (wav, flac recommended).")
728
  ):
729
- """Enhances speech audio using a pre-loaded SpeechBrain model."""
730
- if not AI_LIBS_AVAILABLE: raise HTTPException(status_code=501, detail="AI processing libraries not available.")
731
- # Use the provided key or the default
732
- actual_model_key = model_key or ENHANCEMENT_MODEL_KEY
733
- if actual_model_key not in enhancement_models:
734
- logger.error(f"Enhancement model key '{actual_model_key}' requested but model not loaded.")
735
- raise HTTPException(status_code=503, detail=f"Enhancement model '{actual_model_key}' is not loaded or available. Check server startup logs.")
736
-
737
- loaded_model = enhancement_models[actual_model_key]
738
-
739
- logger.info(f"Enhance request: file='{file.filename}', model='{actual_model_key}', format='{output_format}'")
740
- input_path = await save_upload_file(file, prefix="enhance_in_")
741
- background_tasks.add_task(cleanup_file, input_path)
742
- output_path = None
743
- try:
744
- # Load audio as tensor, ensure correct SR (16kHz)
745
- audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
746
 
747
- logger.info("Submitting enhancement task to background thread...")
748
- enhanced_audio_tensor = await asyncio.to_thread(
749
- _run_enhancement_sync, loaded_model, audio_tensor, current_sr
750
- )
751
- logger.info("Enhancement task completed.")
752
-
753
- # Save the result (tensor output from enhancer at 16kHz)
754
- output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
755
- background_tasks.add_task(cleanup_file, output_path)
756
-
757
- output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
758
- media_type = f"audio/{output_format}" # Basic media type
759
- return FileResponse(path=output_path, media_type=media_type, filename=output_filename)
760
 
761
- except HTTPException as http_exc:
762
- logger.error(f"HTTP Exception during enhancement: {http_exc.detail}")
763
- if output_path: cleanup_file(output_path)
764
- raise http_exc
765
- except Exception as e:
766
- logger.error(f"Unexpected error during enhancement operation: {e}", exc_info=True)
767
- if output_path: cleanup_file(output_path)
768
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during enhancement: {str(e)}")
769
 
770
-
771
- @app.post("/separate", tags=["AI Editing"])
772
- async def separate_sources(
773
- background_tasks: BackgroundTasks,
774
- file: UploadFile = File(..., description="Music audio file to separate into stems."),
775
- model_key: Optional[str] = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use (defaults to primary)."),
776
- stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
777
- output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
778
- ):
779
- """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
780
- if not AI_LIBS_AVAILABLE: raise HTTPException(status_code=501, detail="AI processing libraries not available.")
781
- actual_model_key = model_key or SEPARATION_MODEL_KEY
782
- if actual_model_key not in separation_models:
783
- logger.error(f"Separation model key '{actual_model_key}' requested but model not loaded.")
784
- raise HTTPException(status_code=503, detail=f"Separation model '{actual_model_key}' is not loaded or available. Check server startup logs.")
785
-
786
- loaded_model = separation_models[actual_model_key]
787
- valid_stems = set(loaded_model.sources)
788
- requested_stems = set(s.lower() for s in stems)
789
-
790
- # Check if *any* requested stem is valid
791
- if not requested_stems:
792
- raise HTTPException(status_code=422, detail="No stems requested for separation.")
793
- # Check if *all* requested stems are valid for this model
794
- invalid_stems = requested_stems - valid_stems
795
- if invalid_stems:
796
- raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested: {', '.join(invalid_stems)}. Model '{actual_model_key}' provides: {', '.join(valid_stems)}")
797
-
798
- logger.info(f"Separate request: file='{file.filename}', model='{actual_model_key}', stems={requested_stems}, format='{output_format}'")
799
- input_path = await save_upload_file(file, prefix="separate_in_")
800
- background_tasks.add_task(cleanup_file, input_path)
801
- stem_output_paths: Dict[str, str] = {} # Store paths of successfully saved stems
802
- zip_buffer = io.BytesIO(); zipf = None # Initialize zip buffer and file object
803
 
804
  try:
805
- # Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
806
- audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
807
-
808
- logger.info("Submitting separation task to background thread...")
809
- all_separated_stems_tensors = await asyncio.to_thread(
810
- _run_separation_sync, loaded_model, audio_tensor, current_sr
 
811
  )
812
- logger.info("Separation task completed successfully.")
813
-
814
- # --- Create ZIP file in memory ---
815
- logger.info("Creating ZIP archive in memory...")
816
- zipf = zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED)
817
- files_added_to_zip = 0
818
- for stem_name in requested_stems:
819
- if stem_name in all_separated_stems_tensors:
820
- stem_tensor = all_separated_stems_tensors[stem_name]
821
- stem_path = None # Define stem_path before inner try
822
- try:
823
- # Save stem temporarily (save_hf_audio handles tensor)
824
- # Use the model's native sampling rate for output (DEMUCS_SR)
825
- stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
826
- stem_output_paths[stem_name] = stem_path
827
- # Schedule cleanup AFTER zip is potentially sent
828
- background_tasks.add_task(cleanup_file, stem_path)
829
-
830
- # Use a simpler archive name within the zip
831
- archive_name = f"{stem_name}.{output_format}"
832
- zipf.write(stem_path, arcname=archive_name)
833
- files_added_to_zip += 1
834
- logger.info(f"Added '{archive_name}' to ZIP.")
835
- except Exception as save_err:
836
- # Log error saving/zipping this stem but continue with others
837
- logger.error(f"Failed to save or add stem '{stem_name}' to zip: {save_err}", exc_info=True)
838
- if stem_path: cleanup_file(stem_path) # Clean up if saved but couldn't zip
839
- else:
840
- # This case should be prevented by the earlier validation
841
- logger.warning(f"Requested stem '{stem_name}' not found in model output (validation error?).")
842
-
843
- zipf.close() # Close zip file BEFORE seeking/reading
844
- zipf = None # Clear variable to indicate closed
845
-
846
- if files_added_to_zip == 0:
847
- logger.error("Failed to add any requested stems to the ZIP archive.")
848
- raise HTTPException(status_code=500, detail="Failed to generate any of the requested stems.")
849
-
850
- zip_buffer.seek(0) # Rewind buffer pointer for reading
851
-
852
- # Create final ZIP filename
853
- zip_filename = f"separated_{actual_model_key}_{os.path.splitext(file.filename)[0]}.zip"
854
- logger.info(f"Sending ZIP file: {zip_filename}")
855
- return StreamingResponse(
856
- iter([zip_buffer.getvalue()]), # StreamingResponse needs an iterator
857
- media_type="application/zip",
858
- headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
859
  )
860
- except HTTPException as http_exc:
861
- logger.error(f"HTTP Exception during separation: {http_exc.detail}")
862
- if zipf: zipf.close() # Ensure zipfile is closed
863
- if zip_buffer: zip_buffer.close()
864
- for path in stem_output_paths.values(): cleanup_file(path) # Cleanup successful stems
865
- raise http_exc
866
  except Exception as e:
867
- logger.error(f"Unexpected error during separation operation: {e}", exc_info=True)
868
- if zipf: zipf.close()
869
- if zip_buffer: zip_buffer.close()
870
- for path in stem_output_paths.values(): cleanup_file(path)
871
- raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during separation: {str(e)}")
872
- finally:
873
- # Ensure buffer is closed if not already done
874
- if zip_buffer and not zip_buffer.closed:
875
- zip_buffer.close()
876
 
877
 
878
  # --- How to Run ---
879
- # 1. Ensure FFmpeg is installed and accessible in your PATH.
880
  # 2. Save this code as `app.py`.
881
- # 3. Create `requirements.txt` (including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs, python-multipart, protobuf).
882
- # 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!).
883
- # 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (Use port 7860 for HF Spaces default, remove --reload for production).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
884
  #
885
- # --- WARNING ---
886
- # - AI models require SIGNIFICANT RAM (often 8GB+) and CPU/GPU. Inference can be SLOW (minutes). Free HF Spaces might time out or lack resources.
887
- # - First run downloads models (can take a long time/lots of disk space).
888
- # - Ensure model names (e.g., "htdemucs") are correct.
889
- # - MONITOR STARTUP LOGS carefully for model loading success/failure. Errors here will cause 503 errors later.
 
 
1
  import os
2
  import uuid
3
  import tempfile
4
  import logging
5
+ import shutil
6
+ from typing import List, Optional, Literal
 
7
 
8
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
9
+ from fastapi.responses import FileResponse # JSONResponse removed as not used now
 
 
 
 
10
  from pydub import AudioSegment
11
  from pydub.exceptions import CouldntDecodeError
12
 
13
+ # --- Spleeter (AI Vocal Removal) Imports ---
14
+ # Wrap in try-except to handle potential import errors gracefully
 
 
 
 
 
 
 
 
 
 
 
 
15
  try:
16
+ from spleeter.separator import Separator
17
+ from spleeter.utils import logging as spleeter_logging
18
+ spleeter_available = True
19
+ # Optional: Configure Spleeter logging level (e.g., ERROR to reduce noise)
20
+ # spleeter_logging.set_level(spleeter_logging.ERROR)
21
+ except ImportError:
22
+ spleeter_available = False
23
+ Separator = None # Define Separator as None if import fails
24
+ logging.warning("Spleeter library not found or failed to import.")
25
+ logging.warning("AI Vocal Removal endpoint (/ai/remove-vocals) will be disabled.")
26
+ logging.warning("Install spleeter: pip install spleeter")
27
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # --- Configuration & Setup ---
30
  TEMP_DIR = tempfile.gettempdir()
31
+ os.makedirs(TEMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
32
 
33
+ # Configure logging
34
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
35
  logger = logging.getLogger(__name__)
36
 
37
+ # --- Global Spleeter Separator Initialization ---
38
+ # Load the model once on startup for better request performance.
39
+ # This increases startup time and initial memory usage significantly.
40
+ # Choose the model: 2stems (vocals/accompaniment), 4stems (v/drums/bass/other), 5stems (v/d/b/piano/other)
41
+ # Using 'spleeter:2stems' - downloads model on first use if not cached.
42
+ spleeter_separator: Optional[Separator] = None
43
+ if spleeter_available:
44
+ try:
45
+ logger.info("Initializing Spleeter Separator (Model: spleeter:2stems)... This may download model files.")
46
+ # MWF = Multi-channel Wiener Filtering (can improve quality but slower)
47
+ spleeter_separator = Separator('spleeter:2stems', mwf=False)
48
+ logger.info("Spleeter Separator initialized successfully.")
49
+ except Exception as e:
50
+ logger.error(f"FATAL: Failed to initialize Spleeter Separator: {e}", exc_info=True)
51
+ logger.error("AI Vocal Removal endpoint will likely fail.")
52
+ spleeter_separator = None # Ensure it's None if init failed
 
 
 
53
 
54
+ # --- FastAPI App Initialization ---
55
+ app = FastAPI(
56
+ title="Advanced Audio Editor API",
57
+ description="API for audio editing (trim, concat, volume, convert) and AI Vocal Removal (using Spleeter). Requires FFmpeg.",
58
+ version="2.0.0",
59
+ )
60
 
61
+ # --- Helper Functions (Mostly unchanged, added directory cleanup) ---
62
 
63
def cleanup_path(path: str):
    """Remove a temporary file, symlink, or directory tree; never raises.

    Used both as a FastAPI background task and for eager cleanup in error
    handlers, so the same path may be cleaned more than once. Missing paths
    are therefore silently ignored.

    Args:
        path: Filesystem path to remove. Empty/None values are ignored.
    """
    if not path:
        return

    try:
        if os.path.islink(path):
            # Handle links first: os.path.exists() is False for a dangling
            # link and shutil.rmtree() refuses to operate on a link to a
            # directory, so both cases would otherwise be left behind.
            os.unlink(path)
            logger.info(f"Cleaned up temporary file: {path}")
        elif os.path.isfile(path):
            os.remove(path)
            logger.info(f"Cleaned up temporary file: {path}")
        elif os.path.isdir(path):
            shutil.rmtree(path)
            logger.info(f"Cleaned up temporary directory: {path}")
        elif os.path.lexists(path):
            logger.warning(f"Cleanup attempted on non-file/dir path: {path}")
        # else: nothing exists at this path, nothing to do.
    except FileNotFoundError:
        # The path vanished between the check and the removal (e.g. a
        # concurrent cleanup of the same path); that is the desired end state.
        return
    except Exception as e:
        logger.error(f"Error cleaning up path {path}: {e}", exc_info=True)
 
81
 
82
async def save_upload_file(upload_file: UploadFile) -> str:
    """Stream an uploaded file into its own temporary directory.

    Each upload gets a fresh ``audio_api_upload_<uuid>`` directory under
    ``TEMP_DIR`` and is stored there as ``input<ext>``, so callers can clean
    up by removing the parent directory of the returned path.

    Args:
        upload_file: The incoming upload. It is always closed before returning.

    Returns:
        Path of the saved temporary file.

    Raises:
        HTTPException: 500 if the directory or file could not be written.
    """
    # The client may omit the filename entirely; treat that as "no extension"
    # instead of letting os.path.splitext(None) raise a TypeError.
    original_name = upload_file.filename or ""
    file_extension = os.path.splitext(original_name)[1] or '.tmp'
    request_temp_dir = os.path.join(TEMP_DIR, f"audio_api_upload_{uuid.uuid4().hex}")
    temp_file_path = os.path.join(request_temp_dir, f"input{file_extension}")

    try:
        # Directory creation is inside the try so that a failure here is also
        # reported as a 500 and the upload handle is still closed.
        os.makedirs(request_temp_dir, exist_ok=True)
        with open(temp_file_path, "wb") as buffer:
            # Copy in 1 MiB chunks to avoid holding large uploads in memory.
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_path(request_temp_dir)  # Remove the partial upload directory
        raise HTTPException(
            status_code=500,
            detail=f"Could not save uploaded file: {upload_file.filename}",
        ) from e
    finally:
        await upload_file.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
def load_audio(file_path: str) -> AudioSegment:
    """Decode the audio file at ``file_path`` into a pydub ``AudioSegment``.

    Every failure is translated into an ``HTTPException``: 415 when pydub
    cannot decode the file, 500 when the file is missing or any other error
    occurs while loading.
    """
    display_name = os.path.basename(file_path)
    try:
        segment = AudioSegment.from_file(file_path)
        logger.info(f"Loaded audio from: {file_path} (Duration: {len(segment)}ms)")
        return segment
    except CouldntDecodeError:
        logger.warning(f"pydub couldn't decode file: {file_path}. Unsupported format or corrupted?")
        status = 415
        detail = f"Unsupported audio format or corrupted file: {display_name}"
    except FileNotFoundError:
        logger.error(f"Audio file not found after saving: {file_path}")
        status = 500
        detail = "Internal error: Audio file disappeared."
    except Exception as e:
        logger.error(f"Error loading audio file {file_path}: {e}", exc_info=True)
        status = 500
        detail = f"Error processing audio file: {display_name}"

    # Only reached when one of the handlers above ran.
    raise HTTPException(status_code=status, detail=detail)
119
+
120
def export_audio(audio: AudioSegment, desired_format: str, base_filename: str = "edited_audio") -> str:
    """Export ``audio`` to a uniquely named temporary file and return its path.

    Args:
        audio: The segment to write.
        desired_format: Target format / file extension (case-insensitive).
        base_filename: Human-readable prefix for the temp file name. It is
            usually derived from a client-supplied upload name, so it is
            sanitised before being used in a path.

    Returns:
        Path of the exported file inside ``TEMP_DIR``.

    Raises:
        HTTPException: 422 if ``desired_format`` is not a plain alphanumeric
            token, 500 if the export itself fails.
    """
    fmt = desired_format.lower()
    # The format ends up in the file name; anything but a plain token could
    # smuggle path separators into os.path.join below.
    if not (fmt.isascii() and fmt.isalnum()):
        raise HTTPException(status_code=422, detail=f"Invalid output format '{desired_format}'.")

    # Replace everything outside a conservative whitelist (including '/' and
    # '\\') so the export can never escape TEMP_DIR, and cap the length so the
    # final name stays within common filesystem limits.
    safe_base = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_"
        for ch in base_filename
    )[:100] or "edited_audio"

    output_filename = f"{safe_base}_{uuid.uuid4().hex}.{fmt}"
    output_path = os.path.join(TEMP_DIR, output_filename)

    # pydub passes ``format`` to FFmpeg as the muxer name ("-f"). FFmpeg has
    # no "m4a"/"aac" muxers; the matching ones are "ipod" and "adts".
    ffmpeg_format = {"m4a": "ipod", "aac": "adts"}.get(fmt, fmt)
    # Fixed bitrate for MP3; other formats use the encoder defaults.
    export_params = {"bitrate": "192k"} if fmt == "mp3" else {}

    try:
        logger.info(f"Exporting audio to format '{desired_format}' at {output_path}")
        audio.export(output_path, format=ffmpeg_format, **export_params)
        return output_path
    except Exception as e:
        logger.error(f"Error exporting audio to format {desired_format}: {e}", exc_info=True)
        cleanup_path(output_path)  # Drop any partially written file
        raise HTTPException(
            status_code=500,
            detail=f"Failed to export audio to format '{desired_format}'.",
        ) from e
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  # --- API Endpoints ---
142
 
143
  @app.get("/", tags=["General"])
144
  def read_root():
145
+ """Root endpoint providing a welcome message and feature status."""
146
+ features = ["Trim (/trim)", "Concatenate (/concat)", "Volume (/volume)", "Convert (/convert)"]
147
+ if spleeter_separator:
148
+ features.append("AI Vocal Removal (/ai/remove-vocals)")
 
 
 
 
 
 
 
 
 
 
 
 
149
  else:
150
+ features.append("AI Vocal Removal (Disabled - Spleeter not available)")
 
 
151
  return {
152
+ "message": "Welcome to the Advanced Audio Editor API.",
153
+ "available_features": features,
154
+ "important": "AI Vocal Removal is computationally intensive and may take significant time."
 
 
155
  }
156
 
157
+ # --- Existing Endpoints (Trim, Concat, Volume, Convert) ---
158
+ # Minor changes: Use updated cleanup_path, ensure input cleanup uses the directory
159
+ # Use updated export_audio
160
 
161
@app.post("/trim", tags=["Editing - Pydub"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., description="Start time in milliseconds."),
    end_ms: int = Form(..., description="End time in milliseconds.")
):
    """Trim an uploaded audio file to ``[start_ms, end_ms)`` using pydub.

    The result is exported in the upload's own format (taken from its file
    extension, defaulting to MP3) and returned as a file download.

    Raises:
        HTTPException: 422 for an invalid time range (including a start at or
            beyond the end of the audio); other statuses propagate from the
            load/export helpers; 500 for unexpected errors.
    """
    if start_ms < 0 or end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="Invalid start/end times.")

    # The upload name is client-controlled: it may be missing and may contain
    # directory components, so keep only its final path segment.
    source_name = os.path.basename((file.filename or "").replace("\\", "/")) or "audio"
    stem, extension = os.path.splitext(source_name)
    original_format = extension[1:].lower()
    if original_format in ("", "tmp"):
        original_format = "mp3"

    # Registered MIME types for formats whose type is not "audio/<ext>".
    media_type = {"mp3": "audio/mpeg", "m4a": "audio/mp4"}.get(
        original_format, f"audio/{original_format}"
    )

    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    input_path = await save_upload_file(file)
    input_dir = os.path.dirname(input_path)
    background_tasks.add_task(cleanup_path, input_dir)  # Cleanup after a successful response

    output_path = None
    try:
        audio = load_audio(input_path)
        if start_ms >= len(audio):
            # Slicing past the end would silently produce an empty clip.
            raise HTTPException(
                status_code=422,
                detail=f"start_ms ({start_ms}) is beyond the end of the audio ({len(audio)}ms).",
            )
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")

        output_path = export_audio(trimmed_audio, original_format, base_filename=f"trimmed_{stem}")
        background_tasks.add_task(cleanup_path, output_path)

        return FileResponse(
            path=output_path,
            media_type=media_type,
            filename=f"trimmed_{stem}.{original_format}"
        )
    except Exception as e:
        logger.error(f"Error during trim operation: {e}", exc_info=True)
        # NOTE(review): background tasks are presumably only run together with
        # a successfully returned response, so clean up eagerly on failure.
        # cleanup_path ignores missing paths, so a later duplicate is harmless.
        cleanup_path(input_dir)
        if output_path:
            cleanup_path(output_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during trimming: {str(e)}",
        ) from e
201
 
202
 
203
@app.post("/concat", tags=["Editing - Pydub"])
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
):
    """Join two or more uploaded audio files end-to-end using pydub.

    Files are concatenated in upload order and exported as ``output_format``.

    Raises:
        HTTPException: 422 if fewer than two files are supplied or the format
            is not a plain alphanumeric token; other statuses propagate from
            the load/export helpers; 500 for unexpected errors.
    """
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")

    # The format is used in the temp path, the Content-Type header and the
    # download name, so only a plain token is acceptable.
    safe_output_format = output_format.lower()
    if not (safe_output_format.isascii() and safe_output_format.isalnum()):
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'.")

    # Registered MIME types for formats whose type is not "audio/<ext>".
    media_type = {"mp3": "audio/mpeg", "m4a": "audio/mp4"}.get(
        safe_output_format, f"audio/{safe_output_format}"
    )

    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
    input_dirs = []  # Upload directories, for eager cleanup on failure
    output_path = None

    try:
        segments = []
        for upload in files:
            input_path = await save_upload_file(upload)
            input_dir = os.path.dirname(input_path)
            input_dirs.append(input_dir)
            background_tasks.add_task(cleanup_path, input_dir)
            segments.append(load_audio(input_path))

        combined_audio = segments[0]
        for segment in segments[1:]:
            combined_audio += segment
        logger.info(f"Concatenated audio length: {len(combined_audio)}ms")

        # Client-controlled name: may be missing or contain directory parts.
        first_name = os.path.basename((files[0].filename or "").replace("\\", "/")) or "audio"
        first_filename_base = os.path.splitext(first_name)[0]
        output_base = f"concat_{first_filename_base}_and_{len(files)-1}_others"
        output_path = export_audio(combined_audio, safe_output_format, base_filename=output_base)
        background_tasks.add_task(cleanup_path, output_path)

        return FileResponse(
            path=output_path,
            media_type=media_type,
            filename=f"{output_base}.{safe_output_format}"
        )
    except Exception as e:
        logger.error(f"Error during concat operation: {e}", exc_info=True)
        # NOTE(review): background tasks are presumably only run together with
        # a successfully returned response, so clean up eagerly on failure.
        # cleanup_path ignores missing paths, so a later duplicate is harmless.
        for input_dir in input_dirs:
            cleanup_path(input_dir)
        if output_path:
            cleanup_path(output_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during concatenation: {str(e)}",
        ) from e
250
 
251
 
252
@app.post("/volume", tags=["Editing - Pydub"])
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). +/- values.")
):
    """Apply a gain of ``change_db`` decibels to an uploaded audio file (pydub).

    The result is exported in the upload's own format (taken from its file
    extension, defaulting to MP3) and returned as a file download.

    Raises:
        HTTPException: 422 if ``change_db`` is not a finite number; other
            statuses propagate from the load/export helpers; 500 for
            unexpected errors.
    """
    # float parsing accepts "inf" and "nan"; neither is a meaningful gain.
    # The chained comparison is False for NaN as well as for +/-infinity.
    if not (float("-inf") < change_db < float("inf")):
        raise HTTPException(status_code=422, detail="change_db must be a finite number.")

    # The upload name is client-controlled: it may be missing and may contain
    # directory components, so keep only its final path segment.
    source_name = os.path.basename((file.filename or "").replace("\\", "/")) or "audio"
    stem, extension = os.path.splitext(source_name)
    original_format = extension[1:].lower()
    if original_format in ("", "tmp"):
        original_format = "mp3"

    # Registered MIME types for formats whose type is not "audio/<ext>".
    media_type = {"mp3": "audio/mpeg", "m4a": "audio/mp4"}.get(
        original_format, f"audio/{original_format}"
    )

    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    input_path = await save_upload_file(file)
    input_dir = os.path.dirname(input_path)
    background_tasks.add_task(cleanup_path, input_dir)
    output_path = None

    try:
        audio = load_audio(input_path)
        adjusted_audio = audio + change_db  # pydub: "+ number" applies gain in dB
        logger.info(f"Volume adjusted by {change_db}dB.")

        output_base = f"volume_{change_db}dB_{stem}"
        output_path = export_audio(adjusted_audio, original_format, base_filename=output_base)
        background_tasks.add_task(cleanup_path, output_path)

        return FileResponse(
            path=output_path,
            media_type=media_type,
            filename=f"{output_base}.{original_format}"
        )
    except Exception as e:
        logger.error(f"Error during volume operation: {e}", exc_info=True)
        # NOTE(review): background tasks are presumably only run together with
        # a successfully returned response, so clean up eagerly on failure.
        # cleanup_path ignores missing paths, so a later duplicate is harmless.
        cleanup_path(input_dir)
        if output_path:
            cleanup_path(output_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during volume adjustment: {str(e)}",
        ) from e
287
 
288
 
289
@app.post("/convert", tags=["Editing - Pydub"])
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
):
    """Convert an uploaded audio file to another format using pydub.

    Raises:
        HTTPException: 422 if ``output_format`` is not one of the supported
            formats; other statuses propagate from the load/export helpers;
            500 for unexpected errors.
    """
    # Supported targets mapped to their MIME types; several differ from the
    # naive "audio/<ext>" (e.g. mp3 -> audio/mpeg, m4a -> audio/mp4).
    media_types = {
        'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
        'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4',
    }
    safe_output_format = output_format.lower()
    if safe_output_format not in media_types:
        # sorted() keeps the error message stable between requests.
        raise HTTPException(
            status_code=422,
            detail=f"Invalid output format. Allowed: {', '.join(sorted(media_types))}",
        )

    logger.info(f"Convert request: file='{file.filename}', output_format='{safe_output_format}'")
    input_path = await save_upload_file(file)
    input_dir = os.path.dirname(input_path)
    background_tasks.add_task(cleanup_path, input_dir)
    output_path = None

    try:
        audio = load_audio(input_path)
        # Client-controlled name: may be missing or contain directory parts.
        source_name = os.path.basename((file.filename or "").replace("\\", "/")) or "audio"
        filename_base = os.path.splitext(source_name)[0]
        output_base = f"{filename_base}_converted"

        output_path = export_audio(audio, safe_output_format, base_filename=output_base)
        background_tasks.add_task(cleanup_path, output_path)

        return FileResponse(
            path=output_path,
            media_type=media_types[safe_output_format],
            filename=f"{output_base}.{safe_output_format}"
        )
    except Exception as e:
        logger.error(f"Error during convert operation: {e}", exc_info=True)
        # NOTE(review): background tasks are presumably only run together with
        # a successfully returned response, so clean up eagerly on failure.
        # cleanup_path ignores missing paths, so a later duplicate is harmless.
        cleanup_path(input_dir)
        if output_path:
            cleanup_path(output_path)
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during format conversion: {str(e)}",
        ) from e
325
 
326
 
327
+ # --- AI Vocal Removal Endpoint ---
328
 
329
+ @app.post("/ai/remove-vocals", tags=["Editing - AI"])
330
+ async def ai_remove_vocals(
331
  background_tasks: BackgroundTasks,
332
+ file: UploadFile = File(..., description="Audio file containing mixed vocals and accompaniment."),
333
+ stem_to_return: Literal['accompaniment', 'vocals'] = Form("accompaniment", description="Which stem to return: 'accompaniment' (default) or 'vocals'."),
334
+ output_format: str = Form("wav", description="Output format for the separated stem (e.g., 'wav', 'mp3'). WAV recommended for quality.")
 
335
  ):
336
+ """
337
+ Separates vocals from accompaniment using Spleeter (AI model).
338
+ NOTE: This is computationally intensive and can take significant time.
339
+ """
340
+ if not spleeter_separator:
341
+ logger.warning("Vocal removal endpoint called, but Spleeter is not available.")
342
+ raise HTTPException(status_code=503, detail="AI Vocal Removal service is unavailable (Spleeter not loaded).")
 
 
 
 
 
 
 
 
 
 
343
 
344
+ logger.info(f"AI Vocal Removal request: file='{file.filename}', return='{stem_to_return}', format='{output_format}'")
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
+ input_path = await save_upload_file(file)
347
+ input_dir = os.path.dirname(input_path) # Directory where input was saved
348
+ spleeter_output_dir = os.path.join(TEMP_DIR, f"spleeter_out_{uuid.uuid4().hex}") # Unique output dir for Spleeter
349
+ final_output_path = None # Path to the file that will be returned
 
 
 
 
350
 
351
+ # Schedule cleanup for both input dir and potential Spleeter output dir
352
+ background_tasks.add_task(cleanup_path, input_dir)
353
+ background_tasks.add_task(cleanup_path, spleeter_output_dir) # This will be created by Spleeter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
 
355
  try:
356
+ logger.info(f"Starting Spleeter separation for {input_path} into {spleeter_output_dir}...")
357
+ # Spleeter separates into the specified directory, creating <filename>/vocals.wav and <filename>/accompaniment.wav
358
+ # We pass the input *file* path and the desired *output directory* path.
359
+ spleeter_separator.separate_to_file(
360
+ input_path,
361
+ spleeter_output_dir,
362
+ codec='wav' # Spleeter defaults to WAV, ensuring consistent intermediate format
363
  )
364
+ logger.info(f"Spleeter separation completed.")
365
+
366
+ # Spleeter creates a subdirectory named after the input file (without extension)
367
+ input_filename_base = os.path.splitext(os.path.basename(input_path))[0]
368
+ stem_output_folder = os.path.join(spleeter_output_dir, input_filename_base)
369
+
370
+ # Determine the path to the requested stem file (always WAV from Spleeter)
371
+ target_stem_filename = f"{stem_to_return}.wav"
372
+ raw_stem_path = os.path.join(stem_output_folder, target_stem_filename)
373
+
374
+ if not os.path.exists(raw_stem_path):
375
+ logger.error(f"Spleeter output stem not found: {raw_stem_path}")
376
+ raise HTTPException(status_code=500, detail=f"AI separation failed: Could not find the '{stem_to_return}' stem.")
377
+
378
+ # --- Optional Conversion ---
379
+ safe_output_format = output_format.lower()
380
+ if safe_output_format == "wav":
381
+ # No conversion needed, return the direct Spleeter output
382
+ # We need to move/copy it out of the spleeter dir *or* just return it directly
383
+ # Returning it in place would be the simplest option,
384
+ # BUT FileResponse needs the final path, and background task cleans the whole spleeter_output_dir.
385
+ # SAFER: Copy the desired file out to the main TEMP_DIR before returning.
386
+ final_output_path = os.path.join(TEMP_DIR, f"{input_filename_base}_{stem_to_return}_{uuid.uuid4().hex}.wav")
387
+ shutil.copyfile(raw_stem_path, final_output_path)
388
+ logger.info(f"Copied requested WAV stem to final output path: {final_output_path}")
389
+ background_tasks.add_task(cleanup_path, final_output_path) # Schedule cleanup for the copy
390
+
391
+ else:
392
+ # Convert the WAV stem to the desired format using pydub
393
+ logger.info(f"Loading separated '{stem_to_return}' stem for conversion to '{safe_output_format}'...")
394
+ audio_stem = load_audio(raw_stem_path) # Load the WAV stem
395
+ output_base = f"{input_filename_base}_{stem_to_return}"
396
+ final_output_path = export_audio(audio_stem, safe_output_format, base_filename=output_base)
397
+ logger.info(f"Converted stem saved to: {final_output_path}")
398
+ background_tasks.add_task(cleanup_path, final_output_path) # Schedule cleanup for converted file
399
+
400
+ # --- Return Result ---
401
+ if not final_output_path or not os.path.exists(final_output_path):
402
+ raise HTTPException(status_code=500, detail="Failed to prepare final output file after separation.")
403
+
404
+ return FileResponse(
405
+ path=final_output_path,
406
+ media_type=f"audio/{safe_output_format}", # Use the final format's media type
407
+ filename=os.path.basename(final_output_path) # Use the actual generated filename
 
 
 
408
  )
409
+
 
 
 
 
 
410
  except Exception as e:
411
+ logger.error(f"Error during AI Vocal Removal operation: {e}", exc_info=True)
412
+ if final_output_path: cleanup_path(final_output_path) # Attempt immediate cleanup if needed
413
+ # Input/Spleeter dir cleanup handled by background tasks
414
+ if isinstance(e, HTTPException): raise e
415
+ else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during AI processing: {str(e)}")
 
 
 
 
416
 
417
 
418
  # --- How to Run ---
419
+ # 1. Make sure FFmpeg is installed and accessible in your PATH.
420
  # 2. Save this code as `app.py`.
421
+ # 3. Create `requirements.txt` listing the packages this app needs (e.g. fastapi, uvicorn, pydub, spleeter).
422
+ # 4. Install dependencies: `pip install -r requirements.txt` (THIS MAY TAKE A WHILE!)
423
+ # 5. Run the FastAPI server: `uvicorn app:app --reload`
424
+ #
425
+ # --- Example Usage (using curl) ---
426
+ #
427
+ # **AI Remove Vocals (Get Accompaniment as WAV):**
428
+ # curl -X POST "http://127.0.0.1:8000/ai/remove-vocals" \
429
+ # -F "file=@my_song_mix.mp3" \
430
+ # -F "stem_to_return=accompaniment" \
431
+ # -F "output_format=wav" \
432
+ # --output accompaniment_output.wav
433
+ #
434
+ # **AI Remove Vocals (Get Vocals as MP3):**
435
+ # curl -X POST "http://127.0.0.1:8000/ai/remove-vocals" \
436
+ # -F "file=@another_track.wav" \
437
+ # -F "stem_to_return=vocals" \
438
+ # -F "output_format=mp3" \
439
+ # --output vocals_only_output.mp3
440
  #
441
+ # (Examples for /trim, /concat, /volume and /convert are unchanged from the previous revision and are not repeated here.)