Athspi committed on
Commit 3e135af · verified · 1 Parent(s): 2f8d75b

Update app.py

Files changed (1): app.py +211 -386
app.py CHANGED
@@ -5,6 +5,7 @@ import tempfile
  import logging
  import asyncio
  from typing import List, Optional, Dict, Any
 
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
@@ -16,344 +17,292 @@ from pydub import AudioSegment
  from pydub.exceptions import CouldntDecodeError
 
  # --- AI & Advanced Audio Imports ---
  try:
  import torch
- # Transformers only needed if using HF pipelines directly, not for speechbrain/demucs manual loading
- # from transformers import pipeline
  import soundfile as sf
  import numpy as np
  import librosa
-
- # Specific Model Libraries
  import speechbrain.pretrained
  import demucs.separate
  import demucs.apply
-
- print("AI and advanced audio libraries loaded.")
  except ImportError as e:
- print(f"Error importing AI/Audio libraries: {e}")
- print("Ensure torch, soundfile, librosa, speechbrain, demucs are installed.")
- print("AI features will be unavailable.")
  torch = None
  sf = None
  np = None
  librosa = None
  speechbrain = None
  demucs = None
 
  # --- Configuration & Setup ---
  TEMP_DIR = tempfile.gettempdir()
  os.makedirs(TEMP_DIR, exist_ok=True)
 
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
 
  # --- Global Variables for Loaded Models ---
- # Use consistent keys for storing/retrieving models
  ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
- # Choose a default Demucs model (htdemucs is good quality)
- SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
 
  enhancement_models: Dict[str, Any] = {}
  separation_models: Dict[str, Any] = {}
 
- # Target sampling rates (confirm from model specifics if necessary)
- ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
- DEMUCS_SR = 44100 # Demucs default is 44.1kHz
 
  # --- Device Selection ---
  if torch:
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- logger.info(f"Using device: {DEVICE}")
  else:
- DEVICE = "cpu" # Fallback if torch failed import
 
- # --- Helper Functions ---
  def cleanup_file(file_path: str):
  """Safely remove a file."""
  try:
  if file_path and os.path.exists(file_path):
  os.remove(file_path)
- logger.info(f"Cleaned up temporary file: {file_path}")
  except Exception as e:
  logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
 
  async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
  """Saves an uploaded file to a temporary location and returns the path."""
  _, file_extension = os.path.splitext(upload_file.filename)
- # Default to .wav if no extension, as it's widely compatible for loading
  if not file_extension: file_extension = ".wav"
  temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
  try:
  with open(temp_file_path, "wb") as buffer:
- # Read chunk by chunk for large files
- while content := await upload_file.read(1024 * 1024): # 1MB chunks
- buffer.write(content)
  logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
  return temp_file_path
  except Exception as e:
  logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
- cleanup_file(temp_file_path) # Attempt cleanup if saving failed
  raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
  finally:
- await upload_file.close() # Ensure file handle is closed
 
- # --- Audio Loading/Saving for AI Models ---
  def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
  """Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
  try:
  audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
- logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
-
- # Ensure mono
- if audio.ndim > 1:
- if audio.shape[0] > audio.shape[1]: # Check if channels are likely the first dimension
- audio = audio[0, :] # Take the first channel
- logger.info(f"Selected first channel from multi-channel audio. New shape {audio.shape}")
- else: # Assume channels are the second dimension (common case)
- logger.info(f"Converting {audio.shape[1]} channels to mono by averaging.")
- audio = np.mean(audio, axis=1)
-
- # Convert numpy array to torch tensor
- audio_tensor = torch.from_numpy(audio).float()
 
- # Resample if necessary using librosa
  if target_sr and orig_sr != target_sr:
- if librosa is None: raise RuntimeError("Librosa is required for resampling but not installed.")
- logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
- # Librosa works on numpy
  audio_np = audio_tensor.numpy()
  resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
  audio_tensor = torch.from_numpy(resampled_audio_np).float()
  current_sr = target_sr
- logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
  else:
  current_sr = orig_sr
-
- # Ensure tensor is on the correct device
  return audio_tensor.to(DEVICE), current_sr
-
  except Exception as e:
  logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
- # Clean up the potentially corrupted saved file if loading failed
  cleanup_file(file_path)
- raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format supported by soundfile/libsndfile.")
 
  def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
  """Saves audio data (Tensor or NumPy array) to a temporary file."""
  output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
  output_path = os.path.join(TEMP_DIR, output_filename)
  try:
- logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
-
- # Convert tensor to numpy array if needed
  if isinstance(audio_data, torch.Tensor):
- logger.debug("Converting output tensor to NumPy array.")
- # Ensure tensor is on CPU before converting to numpy
  audio_np = audio_data.detach().cpu().numpy()
  elif isinstance(audio_data, np.ndarray):
  audio_np = audio_data
  else:
- raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
 
- # Ensure data is float32
- if audio_np.dtype != np.float32:
- logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
- audio_np = audio_np.astype(np.float32)
-
- # Clip values to avoid potential issues with formats expecting [-1, 1]
  audio_np = np.clip(audio_np, -1.0, 1.0)
 
- # Use soundfile (preferred for wav/flac)
  if output_format.lower() in ['wav', 'flac']:
  sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
  else:
- # For lossy formats, use pydub
- logger.debug(f"Using pydub to export to lossy format: {output_format}")
- # Scale float32 [-1, 1] to int16 for pydub
- # Ensure audio_np is 1D (mono) before scaling and converting
- if audio_np.ndim > 1:
- logger.warning(f"Audio data has {audio_np.ndim} dimensions, taking first dimension for pydub export.")
- audio_np_mono = audio_np[0] if audio_np.shape[0] < audio_np.shape[1] else audio_np[:, 0] # Basic mono conversion attempt
- else:
- audio_np_mono = audio_np
-
  audio_int16 = (audio_np_mono * 32767).astype(np.int16)
- segment = AudioSegment(
- audio_int16.tobytes(),
- frame_rate=sampling_rate,
- sample_width=audio_int16.dtype.itemsize,
- channels=1 # Assuming mono
- )
  segment.export(output_path, format=output_format)
-
  return output_path
  except Exception as e:
  logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
- cleanup_file(output_path) # Attempt cleanup on saving failure
  raise HTTPException(status_code=500, detail="Failed to save processed audio.")
 
-
- # --- Pydub Loading (for basic edits) ---
  def load_audio_pydub(file_path: str) -> AudioSegment:
- """Loads an audio file using pydub."""
  try:
  audio = AudioSegment.from_file(file_path)
- logger.info(f"Loaded audio using pydub from: {file_path}")
  return audio
- except CouldntDecodeError:
- logger.warning(f"pydub couldn't decode file: {file_path}. Might be unsupported format or corrupted.")
- raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
- except FileNotFoundError:
- logger.error(f"Audio file not found after saving (pydub): {file_path}")
- raise HTTPException(status_code=500, detail="Internal error: Audio file disappeared.")
- except Exception as e:
- logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
- raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")
 
  def export_audio_pydub(audio: AudioSegment, format: str) -> str:
- """Exports a Pydub AudioSegment to a temporary file and returns the path."""
  output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
  output_path = os.path.join(TEMP_DIR, output_filename)
  try:
- logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
  audio.export(output_path, format=format.lower())
  return output_path
- except Exception as e:
- logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
- cleanup_file(output_path) # Cleanup if export failed
- raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
 
- # --- Synchronous AI Inference Functions ---
-
  def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
- """Synchronous wrapper for SpeechBrain enhancement model inference."""
  if not model: raise ValueError("Enhancement model not loaded")
  try:
- logger.info(f"Running speech enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
- # Add batch dimension if needed
- if audio_tensor.ndim == 1:
- audio_tensor = audio_tensor.unsqueeze(0)
-
- # Move tensor to the same device as the model
- model_device = next(model.parameters()).device # Check model's current device
- if audio_tensor.device != model_device:
- audio_tensor = audio_tensor.to(model_device)
-
  with torch.no_grad():
  enhanced_tensor = model.enhance_batch(audio_tensor, lengths=torch.tensor([audio_tensor.shape[1]]).to(model_device))
-
- # Remove batch dimension from output before returning, move back to CPU
  enhanced_audio = enhanced_tensor.squeeze(0).cpu()
  logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
  return enhanced_audio
- except Exception as e:
- logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
- raise
 
  def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
- """Synchronous wrapper for Demucs source separation model inference."""
  if not model: raise ValueError("Separation model not loaded")
- if not demucs: raise RuntimeError("Demucs library not available")
  try:
- logger.info(f"Running source separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
-
- # Move tensor to the same device as the model
  model_device = next(model.parameters()).device
- if audio_tensor.device != model_device:
- audio_tensor = audio_tensor.to(model_device)
-
- # Add batch and channel dimensions if mono (expects batch, channels, samples)
- if audio_tensor.ndim == 1:
- audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
- elif audio_tensor.ndim == 2: # Should be rare if loader works
- audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
-
- # Repeat channel if model expects stereo but input is mono
  if audio_tensor.shape[1] != model.audio_channels:
- if audio_tensor.shape[1] == 1:
- logger.warning(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
- audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
- else:
- # Cannot automatically handle other channel mismatches
- raise ValueError(f"Input audio has {audio_tensor.shape[1]} channels, but Demucs model expects {model.audio_channels}.")
-
- logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
-
  with torch.no_grad():
- # Use demucs.apply.apply_model for handling chunking etc.
- # apply_model expects a tensor of shape (channels, samples)
- # We process one batch item at a time if needed, but typically process the whole file
- audio_to_process = audio_tensor.squeeze(0) # Remove batch dim -> (channels, samples)
  out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25)
- # Output shape (stems, channels, samples)
-
- logger.debug(f"Raw separated sources tensor shape: {out.shape}")
-
- # Map stems based on the model's sources list
  stem_map = {name: out[i] for i, name in enumerate(model.sources)}
-
- # Convert back to mono for simplicity (average channels) and move to CPU
- output_stems = {}
- for name, data in stem_map.items():
- # Average channels, detach, move to CPU
- output_stems[name] = data.mean(dim=0).detach().cpu()
-
- logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
  return output_stems
 
- except Exception as e:
- logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
- raise
 
- # --- Model Loading Function ---
  def load_hf_models():
  """Loads AI models at startup using correct libraries."""
  global enhancement_models, separation_models
- if torch is None or speechbrain is None or demucs is None:
- logger.error("Core AI libraries (torch, speechbrain, demucs) not available. Skipping model loading.")
  return
 
- # --- Load Enhancement Model (SpeechBrain) ---
  enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
  try:
- logger.info(f"Loading enhancement model: {enhancement_model_hparams} (using SpeechBrain)...")
  enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
  source=enhancement_model_hparams,
  run_opts={"device": DEVICE}
  )
  enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
- logger.info(f"Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {DEVICE}.")
  except Exception as e:
- logger.error(f"Failed to load enhancement model '{enhancement_model_hparams}': {e}", exc_info=True)
 
- # --- Load Separation Model (Demucs) ---
  separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs"
  try:
- logger.info(f"Loading separation model: {separation_model_name} (using Demucs package)...")
  separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
  separation_models[SEPARATION_MODEL_KEY] = separator
- logger.info(f"Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {DEVICE}.")
- logger.info(f"Separation model available sources: {separator.sources}")
  except Exception as e:
- logger.error(f"Failed to load separation model '{separation_model_name}': {e}", exc_info=True)
- logger.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
 
  # --- FastAPI App ---
  app = FastAPI(
  title="AI Audio Editor API",
- description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries (torch, speechbrain, demucs).",
- version="2.1.0",
  )
 
  @app.on_event("startup")
  async def startup_event():
- logger.info("Application startup: Loading AI models...")
- await asyncio.to_thread(load_hf_models)
- logger.info("Model loading process finished (check logs for success/failure).")
 
  # --- API Endpoints ---
 
@@ -362,185 +311,98 @@ def read_root():
  """Root endpoint providing a welcome message and available features."""
  features = ["/trim", "/concat", "/volume", "/convert"]
  ai_features = []
  if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
- if separation_models: ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY}, sources: {', '.join(separation_models.get(SEPARATION_MODEL_KEY).sources)})")
 
  return {
  "message": "Welcome to the AI Audio Editor API.",
  "basic_features": features,
  "ai_features": ai_features if ai_features else "None available (check startup logs)",
- "notes": "Requires FFmpeg. AI features require specific models loaded at startup."
  }
 
- # --- Basic Editing Endpoints ---
 
  @app.post("/trim", tags=["Basic Editing"])
378
- async def trim_audio(
379
- background_tasks: BackgroundTasks,
380
- file: UploadFile = File(..., description="Audio file to trim."),
381
- start_ms: int = Form(..., description="Start time in milliseconds."),
382
- end_ms: int = Form(..., description="End time in milliseconds.")
383
- ):
384
- """Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub."""
385
- if start_ms < 0 or end_ms <= start_ms:
386
- raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")
387
-
388
- logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
389
- input_path = await save_upload_file(file, prefix="trim_in_")
390
- background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
391
- output_path = None
392
-
393
  try:
394
  audio = load_audio_pydub(input_path)
395
  trimmed_audio = audio[start_ms:end_ms]
396
- logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
397
-
398
- original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
399
- if not original_format or original_format == "tmp": original_format = "mp3"
400
-
401
- output_path = export_audio_pydub(trimmed_audio, original_format)
402
- background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
403
-
404
- output_filename=f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"
405
-
406
- return FileResponse(
407
- path=output_path,
408
- media_type=f"audio/{original_format}",
409
- filename=output_filename
410
- )
411
  except Exception as e:
412
- logger.error(f"Error during trim operation: {e}", exc_info=True)
413
  if output_path: cleanup_file(output_path)
414
- if isinstance(e, HTTPException): raise e
415
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}")
416
 
417
  @app.post("/concat", tags=["Basic Editing"])
418
- async def concatenate_audio(
419
- background_tasks: BackgroundTasks,
420
- files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
421
- output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
422
- ):
423
- """Concatenates two or more audio files sequentially using Pydub."""
424
- if len(files) < 2:
425
- raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
426
-
427
- logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
428
- input_paths = []
429
- loaded_audios = []
430
- output_path = None
431
-
432
  try:
 
433
  for file in files:
434
- input_path = await save_upload_file(file, prefix="concat_in_")
435
- input_paths.append(input_path)
436
- background_tasks.add_task(cleanup_file, input_path)
437
- audio = load_audio_pydub(input_path)
438
- loaded_audios.append(audio)
439
-
440
- if not loaded_audios: raise HTTPException(status_code=500, detail="No audio segments were loaded successfully.")
441
-
442
- combined_audio = loaded_audios[0]
443
- logger.info(f"Starting concatenation with first segment ({len(combined_audio)}ms)")
444
- for i in range(1, len(loaded_audios)):
445
- logger.info(f"Adding segment {i+1} ({len(loaded_audios[i])}ms)")
446
- combined_audio += loaded_audios[i]
447
-
448
- logger.info(f"Concatenated audio length: {len(combined_audio)}ms")
449
-
450
- output_path = export_audio_pydub(combined_audio, output_format)
451
  background_tasks.add_task(cleanup_file, output_path)
452
-
453
- first_filename_base = os.path.splitext(files[0].filename)[0]
454
- output_filename = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
455
-
456
- return FileResponse(
457
- path=output_path,
458
- media_type=f"audio/{output_format}",
459
- filename=output_filename
460
- )
461
  except Exception as e:
462
- logger.error(f"Error during concat operation: {e}", exc_info=True)
463
- # Cleanup intermediate files if error occurs
464
- for path in input_paths: cleanup_file(path)
465
  if output_path: cleanup_file(output_path)
466
- if isinstance(e, HTTPException): raise e
467
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}")
468
 
469
  @app.post("/volume", tags=["Basic Editing"])
470
- async def change_volume(
471
- background_tasks: BackgroundTasks,
472
- file: UploadFile = File(..., description="Audio file to adjust volume for."),
473
- change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
474
- ):
475
- """Adjusts the volume of an audio file by a specified decibel amount using Pydub."""
476
- logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
477
- input_path = await save_upload_file(file, prefix="volume_in_")
478
- background_tasks.add_task(cleanup_file, input_path)
479
- output_path = None
480
-
481
  try:
482
  audio = load_audio_pydub(input_path)
483
- adjusted_audio = audio + change_db
484
- logger.info(f"Volume adjusted by {change_db}dB.")
485
-
486
- original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
487
- if not original_format or original_format == "tmp": original_format = "mp3"
488
-
489
- output_path = export_audio_pydub(adjusted_audio, original_format)
490
  background_tasks.add_task(cleanup_file, output_path)
491
-
492
- output_filename=f"volume_{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"
493
-
494
- return FileResponse(
495
- path=output_path,
496
- media_type=f"audio/{original_format}",
497
- filename=output_filename
498
- )
499
  except Exception as e:
500
- logger.error(f"Error during volume operation: {e}", exc_info=True)
501
  if output_path: cleanup_file(output_path)
502
- if isinstance(e, HTTPException): raise e
503
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}")
504
 
505
  @app.post("/convert", tags=["Basic Editing"])
506
- async def convert_format(
507
- background_tasks: BackgroundTasks,
508
- file: UploadFile = File(..., description="Audio file to convert."),
509
- output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
510
- ):
511
- """Converts an audio file to a different format using Pydub."""
512
- allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus'} # Common formats
513
- if output_format.lower() not in allowed_formats:
514
- raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed formats: {', '.join(allowed_formats)}")
515
-
516
- logger.info(f"Convert request: file='{file.filename}', output_format='{output_format}'")
517
- input_path = await save_upload_file(file, prefix="convert_in_")
518
- background_tasks.add_task(cleanup_file, input_path)
519
- output_path = None
520
-
521
  try:
522
- # Load using pydub, which handles many input formats
523
  audio = load_audio_pydub(input_path)
524
-
525
  output_path = export_audio_pydub(audio, output_format.lower())
526
  background_tasks.add_task(cleanup_file, output_path)
527
-
528
- filename_base = os.path.splitext(file.filename)[0]
529
- output_filename = f"{filename_base}_converted.{output_format.lower()}"
530
-
531
- return FileResponse(
532
- path=output_path,
533
- media_type=f"audio/{output_format.lower()}", # Media type might need refinement for opus/aac/m4a
534
- filename=output_filename
535
- )
536
  except Exception as e:
537
- logger.error(f"Error during convert operation: {e}", exc_info=True)
538
  if output_path: cleanup_file(output_path)
539
- if isinstance(e, HTTPException): raise e
540
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}")
541
 
542
 
543
- # --- AI Endpoints (Corrected) ---
 
  @app.post("/enhance", tags=["AI Editing"])
  async def enhance_speech(
@@ -550,44 +412,31 @@ async def enhance_speech(
  output_format: str = Form("wav", description="Output format (wav, flac recommended).")
  ):
  """Enhances speech audio using a pre-loaded SpeechBrain model."""
- if torch is None or speechbrain is None:
- raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
  if model_key not in enhancement_models:
  logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
- raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
 
  loaded_model = enhancement_models[model_key]
-
  logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
  input_path = await save_upload_file(file, prefix="enhance_in_")
  background_tasks.add_task(cleanup_file, input_path)
  output_path = None
-
  try:
- # Load audio as tensor, ensure correct SR (16kHz)
  audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
-
  logger.info("Submitting enhancement task to background thread...")
  enhanced_audio_tensor = await asyncio.to_thread(
  _run_enhancement_sync, loaded_model, audio_tensor, current_sr
  )
  logger.info("Enhancement task completed.")
-
- # Save the result (tensor output from enhancer at 16kHz)
  output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
  background_tasks.add_task(cleanup_file, output_path)
-
  output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
- return FileResponse(
- path=output_path,
- media_type=f"audio/{output_format}",
- filename=output_filename
- )
  except Exception as e:
  logger.error(f"Error during enhancement operation: {e}", exc_info=True)
- if output_path: cleanup_file(output_path) # Cleanup output if error occurs after save
- if isinstance(e, HTTPException): raise e
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
 
  @app.post("/separate", tags=["AI Editing"])
@@ -599,55 +448,43 @@ async def separate_sources(
599
  output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
600
  ):
601
  """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
602
- if torch is None or demucs is None:
603
- raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
604
  if model_key not in separation_models:
605
  logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
606
- raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
607
 
608
  loaded_model = separation_models[model_key]
609
- valid_stems = set(loaded_model.sources) # Get stems directly from loaded model
610
  requested_stems = set(s.lower() for s in stems)
611
  if not requested_stems.issubset(valid_stems):
612
- raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Model '{model_key}' provides: {', '.join(valid_stems)}")
613
 
614
  logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
615
  input_path = await save_upload_file(file, prefix="separate_in_")
616
  background_tasks.add_task(cleanup_file, input_path)
617
  stem_output_paths: Dict[str, str] = {}
618
- zip_buffer = None
619
 
620
  try:
621
- # Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
622
  audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
623
-
624
  logger.info("Submitting separation task to background thread...")
625
  all_separated_stems_tensors = await asyncio.to_thread(
626
  _run_separation_sync, loaded_model, audio_tensor, current_sr
627
  )
628
  logger.info("Separation task completed.")
629
 
630
- # --- Create ZIP file in memory ---
631
- zip_buffer = io.BytesIO()
632
- with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
633
- # Save only the requested stems
634
- for stem_name in requested_stems:
635
- if stem_name in all_separated_stems_tensors:
636
- stem_tensor = all_separated_stems_tensors[stem_name]
637
- # Save stem temporarily (save_hf_audio handles tensor)
638
- # Use the model's native sampling rate for output (DEMUCS_SR)
639
- stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
640
- stem_output_paths[stem_name] = stem_path
641
- # Schedule cleanup AFTER zip is sent
642
- background_tasks.add_task(cleanup_file, stem_path)
643
-
644
- # Use a simpler archive name within the zip
645
- archive_name = f"{stem_name}.{output_format}"
646
- zipf.write(stem_path, arcname=archive_name)
647
- logger.info(f"Added '{archive_name}' to ZIP.")
648
- else:
649
- logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen here due to validation).")
650
-
651
  zip_buffer.seek(0)
652
 
653
  zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
@@ -658,23 +495,11 @@ async def separate_sources(
658
  )
659
  except Exception as e:
660
  logger.error(f"Error during separation operation: {e}", exc_info=True)
661
- # Manually trigger cleanup for any stems saved before error
 
 
662
  for path in stem_output_paths.values(): cleanup_file(path)
663
- if zip_buffer: zip_buffer.close() # Ensure buffer is closed on error
664
  if isinstance(e, HTTPException): raise e
665
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
666
-
667
-
668
- # --- How to Run ---
669
- # 1. Ensure FFmpeg is installed and accessible in your PATH.
670
- # 2. Save this code as `app.py`.
671
- # 3. Create `requirements.txt` (as shown in previous responses, including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs).
672
- # 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!).
673
- # 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (Using --host 0.0.0.0 and port 7860 common for HF Spaces)
674
- # Remove --reload for production/stable deployment.
675
- #
676
- # --- WARNING ---
677
- # - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
678
- # - The first run will download models, which can take a long time and lots of disk space.
679
- # - Ensure the specific model IDs/names used (e.g., "htdemucs") are correct and compatible.
680
- # - Monitor startup logs carefully for model loading success or failure.
 
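(For reference, a minimal requirements.txt consistent with the package list in the removed run instructions above might look like the sketch below. The package names come directly from those comments; the commit does not pin any versions, so none are shown here.)

fastapi
uvicorn
pydub
torch
soundfile
librosa
speechbrain
demucs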
  import logging
  import asyncio
  from typing import List, Optional, Dict, Any
+ import traceback # For detailed error logging
 
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 
  from pydub.exceptions import CouldntDecodeError
 
  # --- AI & Advanced Audio Imports ---
+ # Add extra logging around imports
+ logger_init = logging.getLogger("AppInit")
+ logger_init.setLevel(logging.INFO)
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ # Create console handler and set level to info
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.INFO)
+ ch.setFormatter(formatter)
+ logger_init.addHandler(ch)
+
  try:
+ logger_init.info("Importing torch...")
  import torch
+ logger_init.info("Importing soundfile...")
  import soundfile as sf
+ logger_init.info("Importing numpy...")
  import numpy as np
+ logger_init.info("Importing librosa...")
  import librosa
+ logger_init.info("Importing speechbrain...")
  import speechbrain.pretrained
+ logger_init.info("Importing demucs...")
  import demucs.separate
  import demucs.apply
+ logger_init.info("AI and advanced audio libraries imported successfully.")
+ AI_LIBS_AVAILABLE = True
  except ImportError as e:
+ logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
+ logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed.")
+ logger_init.error("AI features will be unavailable.")
  torch = None
  sf = None
  np = None
  librosa = None
  speechbrain = None
  demucs = None
+ AI_LIBS_AVAILABLE = False
 
  # --- Configuration & Setup ---
  TEMP_DIR = tempfile.gettempdir()
  os.makedirs(TEMP_DIR, exist_ok=True)
 
+ # Configure main app logging (use the root logger setup by FastAPI/Uvicorn)
+ logger = logging.getLogger(__name__) # Will inherit root logger settings
 
  # --- Global Variables for Loaded Models ---
  ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
+ SEPARATION_MODEL_KEY = "htdemucs"
 
  enhancement_models: Dict[str, Any] = {}
  separation_models: Dict[str, Any] = {}
 
+ ENHANCEMENT_SR = 16000
+ DEMUCS_SR = 44100
 
  # --- Device Selection ---
  if torch:
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ logger_init.info(f"Using device: {DEVICE}")
  else:
+ DEVICE = "cpu"
+ logger_init.info("Torch not available, defaulting device to CPU.")
 
+ # --- Helper Functions (cleanup_file, save_upload_file - same as before) ---
  def cleanup_file(file_path: str):
  """Safely remove a file."""
  try:
  if file_path and os.path.exists(file_path):
  os.remove(file_path)
+ # logger.info(f"Cleaned up temporary file: {file_path}") # Reduce log noise
  except Exception as e:
  logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
 
  async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
  """Saves an uploaded file to a temporary location and returns the path."""
  _, file_extension = os.path.splitext(upload_file.filename)
  if not file_extension: file_extension = ".wav"
  temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
  try:
  with open(temp_file_path, "wb") as buffer:
+ while content := await upload_file.read(1024 * 1024): buffer.write(content)
  logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
  return temp_file_path
  except Exception as e:
  logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
+ cleanup_file(temp_file_path)
  raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
  finally:
+ await upload_file.close()
 
+ # --- Audio Loading/Saving Functions (same as before) ---
  def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
  """Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
  try:
  audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
+ if audio.ndim > 1: # Ensure mono
+ if audio.shape[0] < audio.shape[1] and audio.shape[0] < 10: # Check if first dim is likely channels
+ audio = audio[0, :]
+ elif audio.shape[1] < audio.shape[0] and audio.shape[1] < 10: # Check if second dim is likely channels
+ audio = audio[:, 0]
+ else: # Fallback: Average if dims are ambiguous or many channels
+ logger.warning(f"Ambiguous audio shape {audio.shape}, averaging channels to mono.")
+ audio = np.mean(audio, axis=1 if audio.shape[1] < audio.shape[0] else 0)
 
+ audio_tensor = torch.from_numpy(audio).float()
  if target_sr and orig_sr != target_sr:
+ if librosa is None: raise RuntimeError("Librosa missing")
+ logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
  audio_np = audio_tensor.numpy()
  resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
  audio_tensor = torch.from_numpy(resampled_audio_np).float()
  current_sr = target_sr
  else:
  current_sr = orig_sr
  return audio_tensor.to(DEVICE), current_sr
  except Exception as e:
  logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
  cleanup_file(file_path)
+ raise HTTPException(status_code=415, detail=f"Could not load/process audio file: {os.path.basename(file_path)}. Check format.")
 
  def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
  """Saves audio data (Tensor or NumPy array) to a temporary file."""
  output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
  output_path = os.path.join(TEMP_DIR, output_filename)
  try:
  if isinstance(audio_data, torch.Tensor):
  audio_np = audio_data.detach().cpu().numpy()
  elif isinstance(audio_data, np.ndarray):
  audio_np = audio_data
  else:
+ raise TypeError(f"Unsupported audio data type: {type(audio_data)}")
 
+ if audio_np.dtype != np.float32: audio_np = audio_np.astype(np.float32)
  audio_np = np.clip(audio_np, -1.0, 1.0)
 
  if output_format.lower() in ['wav', 'flac']:
  sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
  else:
+ if audio_np.ndim > 1: audio_np_mono = np.mean(audio_np, axis=0 if audio_np.shape[0] < audio_np.shape[1] else 1) # Basic mono conversion
+ else: audio_np_mono = audio_np
  audio_int16 = (audio_np_mono * 32767).astype(np.int16)
+ segment = AudioSegment(audio_int16.tobytes(), frame_rate=sampling_rate, sample_width=audio_int16.dtype.itemsize, channels=1)
  segment.export(output_path, format=output_format)
  return output_path
  except Exception as e:
  logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
+ cleanup_file(output_path)
  raise HTTPException(status_code=500, detail="Failed to save processed audio.")
 
+ # --- Pydub Loading/Exporting (for basic edits - same as before) ---
  def load_audio_pydub(file_path: str) -> AudioSegment:
  try:
  audio = AudioSegment.from_file(file_path)
  return audio
+ except CouldntDecodeError: raise HTTPException(status_code=415, detail=f"Unsupported audio format (pydub): {os.path.basename(file_path)}")
+ except Exception as e: raise HTTPException(status_code=500, detail=f"Error processing audio (pydub): {os.path.basename(file_path)}")
 
  def export_audio_pydub(audio: AudioSegment, format: str) -> str:
  output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
  output_path = os.path.join(TEMP_DIR, output_filename)
  try:
  audio.export(output_path, format=format.lower())
  return output_path
+ except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to export audio (pydub): {format}")
 
+ # --- Synchronous AI Inference Functions (same as before) ---
  def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
  if not model: raise ValueError("Enhancement model not loaded")
  try:
+ logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
+ model_device = next(model.parameters()).device
+ if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
+ if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0)
  with torch.no_grad():
  enhanced_tensor = model.enhance_batch(audio_tensor, lengths=torch.tensor([audio_tensor.shape[1]]).to(model_device))
  enhanced_audio = enhanced_tensor.squeeze(0).cpu()
  logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
  return enhanced_audio
+ except Exception as e: logger.error(f"Sync enhancement error: {e}", exc_info=True); raise
 
  def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
  if not model: raise ValueError("Separation model not loaded")
+ if not demucs: raise RuntimeError("Demucs library missing")
  try:
+ logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
  model_device = next(model.parameters()).device
+ if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
+ if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
+ elif audio_tensor.ndim == 2: audio_tensor = audio_tensor.unsqueeze(1)
  if audio_tensor.shape[1] != model.audio_channels:
+ if audio_tensor.shape[1] == 1: audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
+ else: raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
  with torch.no_grad():
+ audio_to_process = audio_tensor.squeeze(0)
  out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25)
  stem_map = {name: out[i] for i, name in enumerate(model.sources)}
+ output_stems = {name: data.mean(dim=0).detach().cpu() for name, data in stem_map.items()}
+ logger.info(f"Separation complete. Stems: {list(output_stems.keys())}")
  return output_stems
+ except Exception as e: logger.error(f"Sync separation error: {e}", exc_info=True); raise
 
+ # --- Model Loading Function (Enhanced Logging) ---
  def load_hf_models():
  """Loads AI models at startup using correct libraries."""
+ logger_load = logging.getLogger("ModelLoader") # Use specific logger
+ logger_load.setLevel(logging.INFO)
+ if not logger_load.handlers: logger_load.addHandler(ch) # Add handler if not already present
+
  global enhancement_models, separation_models
+ if not AI_LIBS_AVAILABLE:
+ logger_load.error("Core AI libraries not available. Cannot load AI models.")
  return
 
+ # --- Load Enhancement Model ---
  enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
+ logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
  try:
+ # Log device before loading
+ logger_load.info(f"Attempting load on device: {DEVICE}")
  enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
  source=enhancement_model_hparams,
  run_opts={"device": DEVICE}
  )
+ # Check model device after loading
+ model_device = next(enhancer.parameters()).device
  enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
+ logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
  except Exception as e:
+ logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False) # Log only message
+ logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
+ logger_load.warning("Enhancement features will be unavailable.")
 
+ # --- Load Separation Model ---
  separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs"
+ logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
  try:
+ logger_load.info(f"Attempting load on device: {DEVICE}")
+ # This automatically handles downloading the model checkpoint
  separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
+ model_device = next(separator.parameters()).device
  separation_models[SEPARATION_MODEL_KEY] = separator
+ logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
+ logger_load.info(f"Separation model sources: {separator.sources}")
  except Exception as e:
+ logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
+ logger_load.error(f"Traceback: {traceback.format_exc()}")
+ logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
+ logger_load.warning("Separation features will be unavailable.")
+
+ logger_load.info("--- Model loading attempts finished ---")
+ logger_load.info(f"Loaded Enhancement Models: {list(enhancement_models.keys())}")
+ logger_load.info(f"Loaded Separation Models: {list(separation_models.keys())}")
 
  # --- FastAPI App ---
  app = FastAPI(
  title="AI Audio Editor API",
+ description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
+ version="2.1.1", # Incremented version
  )
 
  @app.on_event("startup")
  async def startup_event():
+ # Use the init logger for startup messages
+ logger_init.info("--- FastAPI Application Startup ---")
+ if AI_LIBS_AVAILABLE:
+ logger_init.info("AI Libraries appear to be available. Proceeding to load models in background thread...")
+ # Run blocking model load in thread
+ await asyncio.to_thread(load_hf_models)
+ logger_init.info("Background model loading task finished (check ModelLoader logs for details).")
+ else:
+ logger_init.error("AI Libraries failed to import. AI features will be disabled.")
+ logger_init.info("--- Startup complete ---")
 
308
 
 
311
  """Root endpoint providing a welcome message and available features."""
312
  features = ["/trim", "/concat", "/volume", "/convert"]
313
  ai_features = []
314
+ # Check loaded models dictionary status
315
  if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
316
+ if separation_models:
317
+ model = separation_models.get(SEPARATION_MODEL_KEY)
318
+ sources_str = ', '.join(model.sources) if model else 'N/A'
319
+ ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY}, sources: {sources_str})")
320
 
321
  return {
322
  "message": "Welcome to the AI Audio Editor API.",
323
+ "status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
324
+ "loaded_enhancement_models": list(enhancement_models.keys()),
325
+ "loaded_separation_models": list(separation_models.keys()),
326
  "basic_features": features,
327
  "ai_features": ai_features if ai_features else "None available (check startup logs)",
328
+ "notes": "Requires FFmpeg. AI features require models to load successfully at startup."
329
  }
330
 
 
331
 
332
+ # --- Basic Editing Endpoints ---
333
+ # (Add /trim, /concat, /volume, /convert endpoints here - unchanged)
334
  @app.post("/trim", tags=["Basic Editing"])
335
+ async def trim_audio( background_tasks: BackgroundTasks, file: UploadFile = File(...), start_ms: int = Form(...), end_ms: int = Form(...)):
336
+ if start_ms < 0 or end_ms <= start_ms: raise HTTPException(422, "Invalid start/end times.")
337
+ input_path = await save_upload_file(file, "trim_in_")
338
+ background_tasks.add_task(cleanup_file, input_path); output_path = None
 
 
 
 
 
 
 
 
 
 
 
339
  try:
340
  audio = load_audio_pydub(input_path)
341
  trimmed_audio = audio[start_ms:end_ms]
342
+ fmt = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
343
+ output_path = export_audio_pydub(trimmed_audio, fmt)
344
+ background_tasks.add_task(cleanup_file, output_path)
345
+ fname = f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{fmt}"
346
+ return FileResponse(output_path, media_type=f"audio/{fmt}", filename=fname)
 
 
 
 
 
 
 
 
 
 
347
  except Exception as e:
 
348
  if output_path: cleanup_file(output_path)
349
+ if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Trim error: {e}")
 
350
 
351
  @app.post("/concat", tags=["Basic Editing"])
352
+ async def concatenate_audio( background_tasks: BackgroundTasks, files: List[UploadFile] = File(...), output_format: str = Form("mp3")):
353
+ if len(files) < 2: raise HTTPException(422, "Need at least two files.")
354
+ input_paths, loaded_audios, output_path = [], [], None
 
 
 
 
 
 
 
 
 
 
 
355
  try:
356
+ combined = None
357
  for file in files:
358
+ ip = await save_upload_file(file, "concat_in_")
359
+ input_paths.append(ip); background_tasks.add_task(cleanup_file, ip)
360
+ audio = load_audio_pydub(ip)
361
+ combined = (combined + audio) if combined else audio
362
+ if not combined: raise ValueError("No audio loaded.")
363
+ output_path = export_audio_pydub(combined, output_format)
 
 
 
 
 
 
 
 
 
 
 
364
  background_tasks.add_task(cleanup_file, output_path)
365
+ fname = f"concat_{os.path.splitext(files[0].filename)[0]}_{len(files)-1}_others.{output_format}"
366
+ return FileResponse(output_path, media_type=f"audio/{output_format}", filename=fname)
 
 
 
 
 
 
 
367
  except Exception as e:
368
+ for p in input_paths: cleanup_file(p);
 
 
369
  if output_path: cleanup_file(output_path)
370
+ if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Concat error: {e}")
 
371
 
372
  @app.post("/volume", tags=["Basic Editing"])
373
+ async def change_volume( background_tasks: BackgroundTasks, file: UploadFile = File(...), change_db: float = Form(...)):
374
+ input_path = await save_upload_file(file, "volume_in_")
375
+ background_tasks.add_task(cleanup_file, input_path); output_path = None
 
 
 
 
 
 
 
 
376
  try:
377
  audio = load_audio_pydub(input_path)
378
+ adjusted = audio + change_db
379
+ fmt = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
380
+ output_path = export_audio_pydub(adjusted, fmt)
 
 
 
 
381
  background_tasks.add_task(cleanup_file, output_path)
382
+ fname = f"volume_{change_db}dB_{os.path.splitext(file.filename)[0]}.{fmt}"
383
+ return FileResponse(output_path, media_type=f"audio/{fmt}", filename=fname)
 
 
 
 
 
 
384
  except Exception as e:
 
385
  if output_path: cleanup_file(output_path)
386
+ if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Volume error: {e}")
 
387
 
388
  @app.post("/convert", tags=["Basic Editing"])
389
+ async def convert_format( background_tasks: BackgroundTasks, file: UploadFile = File(...), output_format: str = Form(...)):
390
+ allowed = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus'}
391
+ if output_format.lower() not in allowed: raise HTTPException(422, f"Invalid format. Allowed: {allowed}")
392
+ input_path = await save_upload_file(file, "convert_in_")
393
+ background_tasks.add_task(cleanup_file, input_path); output_path = None
 
 
 
 
 
 
 
 
 
 
394
  try:
 
395
  audio = load_audio_pydub(input_path)
 
396
  output_path = export_audio_pydub(audio, output_format.lower())
397
  background_tasks.add_task(cleanup_file, output_path)
398
+ fname = f"{os.path.splitext(file.filename)[0]}_converted.{output_format.lower()}"
399
+ return FileResponse(output_path, media_type=f"audio/{output_format.lower()}", filename=fname)
 
 
 
 
 
 
 
400
  except Exception as e:
 
401
  if output_path: cleanup_file(output_path)
402
+ if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Convert error: {e}")
 
403
 
404
 
405
+ # --- AI Endpoints (Unchanged Functionality, relies on successful loading) ---
 
  @app.post("/enhance", tags=["AI Editing"])
  async def enhance_speech(
  output_format: str = Form("wav", description="Output format (wav, flac recommended).")
  ):
  """Enhances speech audio using a pre-loaded SpeechBrain model."""
+ if not AI_LIBS_AVAILABLE: raise HTTPException(501, "AI libraries not available.")
  if model_key not in enhancement_models:
  logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
+ raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server startup logs.")
 
  loaded_model = enhancement_models[model_key]
  logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
  input_path = await save_upload_file(file, prefix="enhance_in_")
  background_tasks.add_task(cleanup_file, input_path)
  output_path = None
  try:
  audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
  logger.info("Submitting enhancement task to background thread...")
  enhanced_audio_tensor = await asyncio.to_thread(
  _run_enhancement_sync, loaded_model, audio_tensor, current_sr
  )
  logger.info("Enhancement task completed.")
  output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
  background_tasks.add_task(cleanup_file, output_path)
  output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
+ return FileResponse(path=output_path, media_type=f"audio/{output_format}", filename=output_filename)
  except Exception as e:
  logger.error(f"Error during enhancement operation: {e}", exc_info=True)
+ if output_path: cleanup_file(output_path)
+ if isinstance(e, HTTPException): raise e
+ else: raise HTTPException(500, f"Enhancement error: {e}")
 
  @app.post("/separate", tags=["AI Editing"])
 
448
  output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
449
  ):
450
  """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
451
+ if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
 
452
  if model_key not in separation_models:
453
  logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
454
+ raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server startup logs.")
455
 
456
  loaded_model = separation_models[model_key]
457
+ valid_stems = set(loaded_model.sources)
458
  requested_stems = set(s.lower() for s in stems)
459
  if not requested_stems.issubset(valid_stems):
460
+ raise HTTPException(422, f"Invalid stem(s). Model '{model_key}' provides: {valid_stems}")
461
 
462
  logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
463
  input_path = await save_upload_file(file, prefix="separate_in_")
464
  background_tasks.add_task(cleanup_file, input_path)
465
  stem_output_paths: Dict[str, str] = {}
466
+ zip_buffer = io.BytesIO(); zipf = None # Define zipf here
467
 
468
  try:
 
469
  audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
 
470
  logger.info("Submitting separation task to background thread...")
471
  all_separated_stems_tensors = await asyncio.to_thread(
472
  _run_separation_sync, loaded_model, audio_tensor, current_sr
473
  )
474
  logger.info("Separation task completed.")
475
 
476
+ zipf = zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED)
477
+ for stem_name in requested_stems:
478
+ if stem_name in all_separated_stems_tensors:
479
+ stem_tensor = all_separated_stems_tensors[stem_name]
480
+ stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
481
+ stem_output_paths[stem_name] = stem_path
482
+ background_tasks.add_task(cleanup_file, stem_path) # Cleanup after response sent
483
+ archive_name = f"{stem_name}.{output_format}"
484
+ zipf.write(stem_path, arcname=archive_name)
485
+ logger.info(f"Added '{archive_name}' to ZIP.")
486
+
487
+ zipf.close() # Close zip file BEFORE seeking/reading
 
 
 
 
 
 
 
 
 
488
  zip_buffer.seek(0)
489
 
490
  zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
 
495
  )
496
  except Exception as e:
497
  logger.error(f"Error during separation operation: {e}", exc_info=True)
498
+ # Ensure buffer/zipfile are closed and temp files cleaned up on error
499
+ if zipf: zipf.close() # Ensure zipfile is closed
500
+ if zip_buffer: zip_buffer.close()
501
  for path in stem_output_paths.values(): cleanup_file(path)
 
502
  if isinstance(e, HTTPException): raise e
503
+ else: raise HTTPException(500, f"Separation error: {e}")
504
+
505
+ # --- (How to Run instructions remain the same) ---
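(For completeness, a small client sketch for exercising the endpoints in this commit. It assumes the server is running locally on port 7860, per the run instructions, and that the requests package is installed on the client side; the script, file names, and stem choices below are illustrative assumptions, not part of the commit.)

import requests  # assumed client-side dependency, not part of app.py

BASE = "http://localhost:7860"  # assumed local server address from the run instructions

# Trim milliseconds 1000-5000 from an uploaded file via /trim.
with open("speech.wav", "rb") as f:
    resp = requests.post(f"{BASE}/trim",
                         files={"file": ("speech.wav", f, "audio/wav")},
                         data={"start_ms": 1000, "end_ms": 5000})
resp.raise_for_status()
with open("trimmed.wav", "wb") as out:
    out.write(resp.content)

# Request two stems from /separate; the response body is a ZIP archive.
# Stem names must be among the loaded model's sources (see the root endpoint).
with open("song.mp3", "rb") as f:
    resp = requests.post(f"{BASE}/separate",
                         files={"file": ("song.mp3", f, "audio/mpeg")},
                         data=[("stems", "vocals"), ("stems", "drums"),
                               ("output_format", "wav")])
resp.raise_for_status()
with open("stems.zip", "wb") as out:
    out.write(resp.content)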