Athspi committed (verified)
Commit 2c84da8 · 1 Parent(s): 3e135af

Update app.py

Files changed (1): app.py (+566 −182)
app.py CHANGED
@@ -25,8 +25,11 @@ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
- logger_init.addHandler(ch)
+ # Avoid adding handler multiple times if script reloads
+ if not logger_init.handlers:
+     logger_init.addHandler(ch)

+ AI_LIBS_AVAILABLE = False
try:
    logger_init.info("Importing torch...")
    import torch
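The handler guard added above is the standard stdlib way to keep a module-level logger from accumulating duplicate StreamHandlers when the module body runs more than once (e.g., under a reloader). A minimal stand-alone sketch of the same pattern (the logger name here is purely illustrative):

    import logging

    lg = logging.getLogger("init_demo")  # hypothetical logger name for illustration
    for _ in range(2):  # simulate the module body running twice
        if not lg.handlers:  # the guard introduced in this commit
            lg.addHandler(logging.StreamHandler())
    assert len(lg.handlers) == 1  # without the guard there would be two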
@@ -45,189 +48,333 @@ try:
    AI_LIBS_AVAILABLE = True
except ImportError as e:
    logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
-     logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed.")
+     logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
    logger_init.error("AI features will be unavailable.")
+     # Define placeholders so the rest of the code doesn't break completely on import error
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None
-     AI_LIBS_AVAILABLE = False

# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
- os.makedirs(TEMP_DIR, exist_ok=True)
+ # Attempt to create temp dir if it doesn't exist (useful in some environments)
+ try:
+     os.makedirs(TEMP_DIR, exist_ok=True)
+ except OSError as e:
+     logger_init.error(f"Could not create temporary directory {TEMP_DIR}: {e}")
+     # Fallback or raise an error depending on desired behavior
+     TEMP_DIR = "." # Use current directory as fallback (less ideal)
+     logger_init.warning(f"Using current directory '{TEMP_DIR}' for temporary files.")
+

# Configure main app logging (use the root logger setup by FastAPI/Uvicorn)
- logger = logging.getLogger(__name__) # Will inherit root logger settings
+ # This logger will be used by endpoint handlers
+ logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
- SEPARATION_MODEL_KEY = "htdemucs"
+ # Choose a default Demucs model (htdemucs is good quality)
+ SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one

enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

- ENHANCEMENT_SR = 16000
- DEMUCS_SR = 44100
+ # Target sampling rates (confirm from model specifics if necessary)
+ ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
+ DEMUCS_SR = 44100 # Demucs default is 44.1kHz

# --- Device Selection ---
- if torch:
+ if AI_LIBS_AVAILABLE and torch:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-     logger_init.info(f"Using device: {DEVICE}")
+     logger_init.info(f"Selected device for AI models: {DEVICE}")
else:
-     DEVICE = "cpu"
-     logger_init.info("Torch not available, defaulting device to CPU.")
+     DEVICE = "cpu" # Fallback if torch failed import
+     logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")


- # --- Helper Functions (cleanup_file, save_upload_file - same as before) ---
+ # --- Helper Functions ---
+
def cleanup_file(file_path: str):
    """Safely remove a file."""
    try:
-         if file_path and os.path.exists(file_path):
+         if file_path and isinstance(file_path, str) and os.path.exists(file_path):
            os.remove(file_path)
            # logger.info(f"Cleaned up temporary file: {file_path}") # Reduce log noise
    except Exception as e:
+         # Log error but don't crash the cleanup process for other files
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)

async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Saves an uploaded file to a temporary location and returns the path."""
+     if not upload_file or not upload_file.filename:
+         raise HTTPException(status_code=400, detail="Invalid file upload object.")
+
    _, file_extension = os.path.splitext(upload_file.filename)
+     # Default to .wav if no extension, as it's widely compatible for loading
    if not file_extension: file_extension = ".wav"
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
+
    try:
+         logger.debug(f"Attempting to save uploaded file to: {temp_file_path}")
        with open(temp_file_path, "wb") as buffer:
-             while content := await upload_file.read(1024 * 1024): buffer.write(content)
-         logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
+             # Read chunk by chunk for large files
+             while content := await upload_file.read(1024 * 1024): # 1MB chunks
+                 buffer.write(content)
+         logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
-         logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
-         cleanup_file(temp_file_path)
+         logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {temp_file_path}: {e}", exc_info=True)
+         cleanup_file(temp_file_path) # Attempt cleanup if saving failed
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
-         await upload_file.close()
+         # Ensure file is closed even if saving fails mid-way
+         try:
+             await upload_file.close()
+         except Exception:
+             pass # Ignore errors during close if already failed
+
+
+ # --- Audio Loading/Saving Functions ---

- # --- Audio Loading/Saving Functions (same as before) ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
-     """Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
-     # ... (Function definition remains the same) ...
+     """Loads audio using soundfile, converts to mono float32 Torch tensor, optionally resamples."""
+     if not AI_LIBS_AVAILABLE:
+         raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
+     if not os.path.exists(file_path):
+         raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")
+
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
-         # logger.debug(...) # Keep debug logs if needed
-         if audio.ndim > 1: # Ensure mono
-             if audio.shape[0] < audio.shape[1] and audio.shape[0] < 10: # Check if first dim is likely channels
-                 audio = audio[0, :]
-             elif audio.shape[1] < audio.shape[0] and audio.shape[1] < 10: # Check if second dim is likely channels
-                 audio = audio[:, 0]
-             else: # Fallback: Average if dims are ambiguous or many channels
-                 logger.warning(f"Ambiguous audio shape {audio.shape}, averaging channels to mono.")
-                 audio = np.mean(audio, axis=1 if audio.shape[1] < audio.shape[0] else 0)
+         logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")
+
+         # Ensure mono
+         if audio.ndim > 1:
+             # Check which dimension is smaller (likely channels)
+             channel_dim = np.argmin(audio.shape)
+             if audio.shape[channel_dim] > 1 and audio.shape[channel_dim] < 10: # Heuristic: <10 channels
+                 logger.info(f"Detected {audio.shape[channel_dim]} channels. Converting to mono by averaging axis {channel_dim}.")
+                 audio = np.mean(audio, axis=channel_dim)
+             else: # Fallback or if shape is ambiguous (e.g., very short stereo)
+                 logger.warning(f"Audio has shape {audio.shape}. Taking first channel/element assuming mono or channel-first.")
+                 audio = audio[0] if channel_dim == 0 else audio[:, 0] # Select first index of the likely channel dimension
+
+         logger.debug(f"Shape after mono conversion: {audio.shape}")
+

+         # Ensure it's now 1D
+         audio = audio.flatten()
+
+         # Convert numpy array to torch tensor
        audio_tensor = torch.from_numpy(audio).float()
+
+         # Resample if necessary using librosa
+         current_sr = orig_sr
        if target_sr and orig_sr != target_sr:
-             if librosa is None: raise RuntimeError("Librosa missing")
+             if librosa is None: raise RuntimeError("Librosa missing for resampling")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
+             # Librosa works on numpy
            audio_np = audio_tensor.numpy()
-             resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
+             resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best') # Specify resampling type
            audio_tensor = torch.from_numpy(resampled_audio_np).float()
            current_sr = target_sr
-         else:
-             current_sr = orig_sr
+             logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
+
+         # Ensure tensor is on the correct device
        return audio_tensor.to(DEVICE), current_sr
+
+     except sf.SoundFileError as sf_err:
+         logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
+         cleanup_file(file_path)
+         raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}")
    except Exception as e:
-         logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
+         logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
        cleanup_file(file_path)
-         raise HTTPException(status_code=415, detail=f"Could not load/process audio file: {os.path.basename(file_path)}. Check format.")
+         raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.")


def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves audio data (Tensor or NumPy array) to a temporary file."""
-     # ... (Function definition remains the same) ...
+     if not AI_LIBS_AVAILABLE:
+         raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
+
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
-         # logger.debug(...) # Keep debug logs if needed
+         logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
+
+         # Convert tensor to numpy array if needed
        if isinstance(audio_data, torch.Tensor):
+             logger.debug("Converting output tensor to NumPy array.")
+             # Ensure tensor is on CPU before converting to numpy
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
-             raise TypeError(f"Unsupported audio data type: {type(audio_data)}")
+             raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
+
+         # Ensure data is float32
+         if audio_np.dtype != np.float32:
+             logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
+             audio_np = audio_np.astype(np.float32)

-         if audio_np.dtype != np.float32: audio_np = audio_np.astype(np.float32)
+         # Clip values to avoid potential issues with formats expecting [-1, 1]
        audio_np = np.clip(audio_np, -1.0, 1.0)

+         # Ensure audio is 1D (mono) before saving with soundfile or pydub conversion
+         if audio_np.ndim > 1:
+             logger.warning(f"Output audio data has {audio_np.ndim} dimensions, attempting to flatten or take first dimension.")
+             # Try averaging channels if shape suggests stereo/multi-channel
+             channel_dim = np.argmin(audio_np.shape)
+             if audio_np.shape[channel_dim] > 1 and audio_np.shape[channel_dim] < 10:
+                 audio_np = np.mean(audio_np, axis=channel_dim)
+             else: # Otherwise just flatten
+                 audio_np = audio_np.flatten()
+
+
+         # Use soundfile (preferred for wav/flac)
        if output_format.lower() in ['wav', 'flac']:
            sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
        else:
-             if audio_np.ndim > 1: audio_np_mono = np.mean(audio_np, axis=0 if audio_np.shape[0] < audio_np.shape[1] else 1) # Basic mono conversion
-             else: audio_np_mono = audio_np
-             audio_int16 = (audio_np_mono * 32767).astype(np.int16)
-             segment = AudioSegment(audio_int16.tobytes(), frame_rate=sampling_rate, sample_width=audio_int16.dtype.itemsize, channels=1)
+             # For lossy formats, use pydub
+             logger.debug(f"Using pydub to export to lossy format: {output_format}")
+             # Scale float32 [-1, 1] to int16 for pydub
+             audio_int16 = (audio_np * 32767).astype(np.int16)
+             segment = AudioSegment(
+                 audio_int16.tobytes(),
+                 frame_rate=sampling_rate,
+                 sample_width=audio_int16.dtype.itemsize,
+                 channels=1 # Assuming mono after processing above
+             )
+             # Pydub might need explicit ffmpeg path in some envs
+             # AudioSegment.converter = "/path/to/ffmpeg" # Uncomment and set path if needed
            segment.export(output_path, format=output_format)
+
+         logger.info(f"Successfully saved AI audio to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
-         cleanup_file(output_path)
-         raise HTTPException(status_code=500, detail="Failed to save processed audio.")
+         cleanup_file(output_path) # Attempt cleanup on saving failure
+         raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.")

- # --- Pydub Loading/Exporting (for basic edits - same as before) ---
+
+ # --- Pydub Loading/Exporting (for basic edits) ---
def load_audio_pydub(file_path: str) -> AudioSegment:
-     # ... (Function definition remains the same) ...
+     """Loads an audio file using pydub."""
+     if not os.path.exists(file_path):
+         raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")
    try:
-         audio = AudioSegment.from_file(file_path)
+         logger.debug(f"Loading audio with pydub: {file_path}")
+         # Explicitly provide format if possible, helps pydub sometimes
+         file_ext = os.path.splitext(file_path)[1][1:].lower()
+         if file_ext:
+             audio = AudioSegment.from_file(file_path, format=file_ext)
+         else:
+             audio = AudioSegment.from_file(file_path) # Let pydub detect
+         logger.info(f"Loaded audio using pydub from: {file_path}")
        return audio
-     except CouldntDecodeError: raise HTTPException(status_code=415, detail=f"Unsupported audio format (pydub): {os.path.basename(file_path)}")
-     except Exception as e: raise HTTPException(status_code=500, detail=f"Error processing audio (pydub): {os.path.basename(file_path)}")
+     except CouldntDecodeError as e:
+         logger.warning(f"Pydub CouldntDecodeError for {file_path}: {e}")
+         cleanup_file(file_path)
+         raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
+     except Exception as e:
+         logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
+         cleanup_file(file_path)
+         raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")

def export_audio_pydub(audio: AudioSegment, format: str) -> str:
-     # ... (Function definition remains the same) ...
+     """Exports a Pydub AudioSegment to a temporary file and returns the path."""
    output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
+         logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
        audio.export(output_path, format=format.lower())
        return output_path
-     except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to export audio (pydub): {format}")
+     except Exception as e:
+         logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
+         cleanup_file(output_path) # Cleanup if export failed
+         raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
+

+ # --- Synchronous AI Inference Functions ---

- # --- Synchronous AI Inference Functions (same as before) ---
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
-     # ... (Function definition remains the same) ...
-     if not model: raise ValueError("Enhancement model not loaded")
+     """Synchronous wrapper for SpeechBrain enhancement model inference."""
+     if not AI_LIBS_AVAILABLE or not model: raise ValueError("Enhancement model/libs not available")
    try:
        logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
-         model_device = next(model.parameters()).device
+         model_device = next(model.parameters()).device # Check model's current device
        if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
+         # Add batch dimension if model expects it (most do)
        if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0)
+
        with torch.no_grad():
-             enhanced_tensor = model.enhance_batch(audio_tensor, lengths=torch.tensor([audio_tensor.shape[1]]).to(model_device))
+             # Check if model expects lengths parameter
+             enhance_method = getattr(model, "enhance_batch", getattr(model, "forward", None))
+             if "lengths" in enhance_method.__code__.co_varnames:
+                 enhanced_tensor = enhance_method(audio_tensor, lengths=torch.tensor([audio_tensor.shape[-1]]).to(model_device))
+             else:
+                 enhanced_tensor = enhance_method(audio_tensor)
+
+
+         # Remove batch dimension from output before returning, move back to CPU
        enhanced_audio = enhanced_tensor.squeeze(0).cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
-     except Exception as e: logger.error(f"Sync enhancement error: {e}", exc_info=True); raise
+     except Exception as e:
+         logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
+         raise # Re-raise to be caught by the async wrapper

def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
-     # ... (Function definition remains the same) ...
-     if not model: raise ValueError("Separation model not loaded")
+     """Synchronous wrapper for Demucs source separation model inference."""
+     if not AI_LIBS_AVAILABLE or not model: raise ValueError("Separation model/libs not available")
    if not demucs: raise RuntimeError("Demucs library missing")
    try:
        logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device
        if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
-         if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
-         elif audio_tensor.ndim == 2: audio_tensor = audio_tensor.unsqueeze(1)
+
+         # Demucs expects audio as (batch, channels, samples)
+         if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
+         elif audio_tensor.ndim == 2: audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
+
+         # Repeat channel if model expects stereo but input is mono
        if audio_tensor.shape[1] != model.audio_channels:
-             if audio_tensor.shape[1] == 1: audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
-             else: raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
+             if audio_tensor.shape[1] == 1:
+                 logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
+                 audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
+             else:
+                 raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
+
+         logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
+
        with torch.no_grad():
+             # Use demucs.apply.apply_model for handling chunking etc.
+             # Requires input shape (channels, samples) - process first batch item
            audio_to_process = audio_tensor.squeeze(0)
-             out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25)
+             # Note: shifts=1, split=True are common defaults for quality
+             out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25, progress=False) # Disable progress bar in logs
+             # Output shape (stems, channels, samples)
+
+         logger.debug(f"Raw separated sources tensor shape: {out.shape}")
+
+         # Map stems based on the model's sources list
        stem_map = {name: out[i] for i, name in enumerate(model.sources)}
-         output_stems = {name: data.mean(dim=0).detach().cpu() for name, data in stem_map.items()}
-         logger.info(f"Separation complete. Stems: {list(output_stems.keys())}")
+
+         # Convert back to mono for simplicity (average channels) and move to CPU
+         output_stems = {}
+         for name, data in stem_map.items():
+             # Average channels, detach, move to CPU
+             output_stems[name] = data.mean(dim=0).detach().cpu()
+
+         logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems
-     except Exception as e: logger.error(f"Sync separation error: {e}", exc_info=True); raise
+
+     except Exception as e:
+         logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
+         raise


# --- Model Loading Function (Enhanced Logging) ---
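The new mono-conversion logic in load_audio_for_hf treats the smaller array dimension as the channel axis and averages across it. A numpy-only sketch of that heuristic, with an illustrative (samples, channels) shape:

    import numpy as np

    audio = np.random.rand(44100, 2).astype(np.float32)  # as returned by sf.read
    if audio.ndim > 1:
        channel_dim = int(np.argmin(audio.shape))  # smaller dim assumed to be channels -> 1
        if 1 < audio.shape[channel_dim] < 10:      # heuristic: fewer than 10 channels
            audio = np.mean(audio, axis=channel_dim)
    assert audio.shape == (44100,)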
@@ -235,29 +382,35 @@ def load_hf_models():
    """Loads AI models at startup using correct libraries."""
    logger_load = logging.getLogger("ModelLoader") # Use specific logger
    logger_load.setLevel(logging.INFO)
-     if not logger_load.handlers: logger_load.addHandler(ch) # Add handler if not already present
+     # Ensure handler is attached if logger is newly created
+     if not logger_load.handlers and ch: logger_load.addHandler(ch)

    global enhancement_models, separation_models
    if not AI_LIBS_AVAILABLE:
        logger_load.error("Core AI libraries not available. Cannot load AI models.")
        return

+     load_success_flags = {"enhancement": False, "separation": False}
+
    # --- Load Enhancement Model ---
    enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
    logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
    try:
-         # Log device before loading
        logger_load.info(f"Attempting load on device: {DEVICE}")
+         # Consider adding savedir if cache issues arise in HF Spaces
+         # savedir_sb = os.path.join(TEMP_DIR, "speechbrain_models")
+         # os.makedirs(savedir_sb, exist_ok=True)
        enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
            source=enhancement_model_hparams,
+             # savedir=savedir_sb,
            run_opts={"device": DEVICE}
        )
-         # Check model device after loading
        model_device = next(enhancer.parameters()).device
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
        logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
+         load_success_flags["enhancement"] = True
    except Exception as e:
-         logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False) # Log only message
+         logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
        logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
        logger_load.warning("Enhancement features will be unavailable.")

@@ -267,28 +420,29 @@ def load_hf_models():
    logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
-         # This automatically handles downloading the model checkpoint
+         # This automatically handles downloading the model checkpoint via demucs package
        separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
        model_device = next(separator.parameters()).device
        separation_models[SEPARATION_MODEL_KEY] = separator
        logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
-         logger_load.info(f"Separation model sources: {separator.sources}")
+         logger_load.info(f"Separation model available sources: {separator.sources}")
+         load_success_flags["separation"] = True
    except Exception as e:
        logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
        logger_load.error(f"Traceback: {traceback.format_exc()}")
-         logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
+         logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
        logger_load.warning("Separation features will be unavailable.")

    logger_load.info(f"--- Model loading attempts finished ---")
-     logger_load.info(f"Loaded Enhancement Models: {list(enhancement_models.keys())}")
-     logger_load.info(f"Loaded Separation Models: {list(separation_models.keys())}")
+     logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
+     logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")


# --- FastAPI App ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
-     version="2.1.1", # Incremented version
+     version="2.1.2", # Incremented version
)

@app.on_event("startup")
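Both loaders confirm placement with next(model.parameters()).device after loading. The same check on a trivial torch module, assuming torch is installed (the Linear layer is just a stand-in for a loaded model):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for a loaded model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(next(model.parameters()).device)  # where the weights actually live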
@@ -296,210 +450,440 @@ async def startup_event():
    # Use the init logger for startup messages
    logger_init.info("--- FastAPI Application Startup ---")
    if AI_LIBS_AVAILABLE:
-         logger_init.info("AI Libraries appear to be available. Proceeding to load models in background thread...")
+         logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
        # Run blocking model load in thread
        await asyncio.to_thread(load_hf_models)
-         logger_init.info("Background model loading task finished (check ModelLoader logs for details).")
+         logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
    else:
-         logger_init.error("AI Libraries failed to import. AI features will be disabled.")
-     logger_init.info("--- Startup complete ---")
+         logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
+     logger_init.info("--- Startup sequence complete ---")

# --- API Endpoints ---

@app.get("/", tags=["General"])
def read_root():
-     """Root endpoint providing a welcome message and available features."""
+     """Root endpoint providing a welcome message and status of loaded models."""
    features = ["/trim", "/concat", "/volume", "/convert"]
-     ai_features = []
-     # Check loaded models dictionary status
-     if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
-     if separation_models:
-         model = separation_models.get(SEPARATION_MODEL_KEY)
-         sources_str = ', '.join(model.sources) if model else 'N/A'
-         ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY}, sources: {sources_str})")
+     ai_features_status = {}
+
+     if AI_LIBS_AVAILABLE:
+         if enhancement_models:
+             ai_features_status[ENHANCEMENT_MODEL_KEY] = "Loaded"
+         else:
+             ai_features_status[ENHANCEMENT_MODEL_KEY] = "Failed to load (check startup logs)"
+
+         if separation_models:
+             model = separation_models.get(SEPARATION_MODEL_KEY)
+             sources_str = ', '.join(model.sources) if model else 'N/A'
+             ai_features_status[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {sources_str})"
+         else:
+             ai_features_status[SEPARATION_MODEL_KEY] = "Failed to load (check startup logs)"
+     else:
+         ai_features_status["AI Status"] = "Libraries Failed Import"
+

    return {
        "message": "Welcome to the AI Audio Editor API.",
        "status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
-         "loaded_enhancement_models": list(enhancement_models.keys()),
-         "loaded_separation_models": list(separation_models.keys()),
-         "basic_features": features,
-         "ai_features": ai_features if ai_features else "None available (check startup logs)",
-         "notes": "Requires FFmpeg. AI features require models to load successfully at startup."
+         "ai_models_status": ai_features_status,
+         "basic_endpoints": features,
+         "notes": "Requires FFmpeg. AI features require successful model loading at startup."
    }


# --- Basic Editing Endpoints ---
- # (Add /trim, /concat, /volume, /convert endpoints here - unchanged)
+
@app.post("/trim", tags=["Basic Editing"])
- async def trim_audio( background_tasks: BackgroundTasks, file: UploadFile = File(...), start_ms: int = Form(...), end_ms: int = Form(...)):
-     if start_ms < 0 or end_ms <= start_ms: raise HTTPException(422, "Invalid start/end times.")
-     input_path = await save_upload_file(file, "trim_in_")
-     background_tasks.add_task(cleanup_file, input_path); output_path = None
+ async def trim_audio(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to trim."),
+     start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
+     end_ms: int = Form(..., gt=0, description="End time in milliseconds.") # Ensure end > 0
+ ):
+     """Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub."""
+     if end_ms <= start_ms:
+         raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")
+
+     logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
+     input_path = await save_upload_file(file, prefix="trim_in_")
+     # Schedule cleanup immediately after saving, even if loading fails later
+     background_tasks.add_task(cleanup_file, input_path)
+     output_path = None # Define before try block
+
    try:
-         audio = load_audio_pydub(input_path)
+         audio = load_audio_pydub(input_path) # Can raise HTTPException
        trimmed_audio = audio[start_ms:end_ms]
-         fmt = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
-         output_path = export_audio_pydub(trimmed_audio, fmt)
-         background_tasks.add_task(cleanup_file, output_path)
-         fname = f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{fmt}"
-         return FileResponse(output_path, media_type=f"audio/{fmt}", filename=fname)
+         logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
+
+         # Determine original format for export
+         original_format = os.path.splitext(file.filename)[1][1:].lower()
+         # Use mp3 as default only if no extension or if it's 'tmp' etc.
+         if not original_format or len(original_format) > 5: # Basic check for valid extension length
+             original_format = "mp3"
+             logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")
+
+         output_path = export_audio_pydub(trimmed_audio, original_format) # Can raise HTTPException
+         background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
+
+         # Create a more informative filename
+         output_filename=f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{original_format}", # Best guess for media type
+             filename=output_filename
+         )
+     except HTTPException as http_exc:
+         # If load/export raised HTTPException, re-raise it
+         # Cleanup might have already been scheduled, background tasks handle errors
+         logger.error(f"HTTP Exception during trim: {http_exc.detail}")
+         if output_path: cleanup_file(output_path) # Try immediate cleanup if output exists
+         raise http_exc
    except Exception as e:
+         # Catch other unexpected errors during trimming logic
+         logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
-         if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Trim error: {e}")
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}")
+

@app.post("/concat", tags=["Basic Editing"])
- async def concatenate_audio( background_tasks: BackgroundTasks, files: List[UploadFile] = File(...), output_format: str = Form("mp3")):
-     if len(files) < 2: raise HTTPException(422, "Need at least two files.")
-     input_paths, loaded_audios, output_path = [], [], None
+ async def concatenate_audio(
+     background_tasks: BackgroundTasks,
+     files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
+     output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
+ ):
+     """Concatenates two or more audio files sequentially using Pydub."""
+     if len(files) < 2:
+         raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
+
+     logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
+     input_paths = [] # Keep track of all saved input file paths
+     output_path = None # Define before try block
+
    try:
-         combined = None
-         for file in files:
-             ip = await save_upload_file(file, "concat_in_")
-             input_paths.append(ip); background_tasks.add_task(cleanup_file, ip)
-             audio = load_audio_pydub(ip)
-             combined = (combined + audio) if combined else audio
-         if not combined: raise ValueError("No audio loaded.")
-         output_path = export_audio_pydub(combined, output_format)
-         background_tasks.add_task(cleanup_file, output_path)
-         fname = f"concat_{os.path.splitext(files[0].filename)[0]}_{len(files)-1}_others.{output_format}"
-         return FileResponse(output_path, media_type=f"audio/{output_format}", filename=fname)
+         combined_audio: Optional[AudioSegment] = None
+         for i, file in enumerate(files):
+             if not file or not file.filename:
+                 logger.warning(f"Skipping invalid file upload at index {i}.")
+                 continue # Skip potentially empty file entries
+
+             input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
+             input_paths.append(input_path)
+             # Schedule cleanup for this specific input file immediately
+             background_tasks.add_task(cleanup_file, input_path)
+
+             try:
+                 audio = load_audio_pydub(input_path)
+                 if combined_audio is None:
+                     combined_audio = audio
+                     logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
+                 else:
+                     logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
+                     combined_audio += audio
+             except HTTPException as load_exc:
+                 # Log error but continue trying to load other files if possible
+                 logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
+             except Exception as load_exc:
+                 logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)
+
+
+         if combined_audio is None:
+             raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")
+
+         logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")
+
+         output_path = export_audio_pydub(combined_audio, output_format) # Can raise HTTPException
+         background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
+
+         # Determine a reasonable output filename
+         first_valid_filename = files[0].filename if files and files[0] else "audio"
+         first_filename_base = os.path.splitext(first_valid_filename)[0]
+         output_filename = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{output_format}",
+             filename=output_filename
+         )
+     except HTTPException as http_exc:
+         # If load/export raised HTTPException, re-raise it
+         logger.error(f"HTTP Exception during concat: {http_exc.detail}")
+         # Cleanup for output path, inputs are handled by background tasks
+         if output_path: cleanup_file(output_path)
+         raise http_exc
    except Exception as e:
-         for p in input_paths: cleanup_file(p);
+         # Catch other unexpected errors during combining logic
+         logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
-         if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Concat error: {e}")
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}")
+

@app.post("/volume", tags=["Basic Editing"])
- async def change_volume( background_tasks: BackgroundTasks, file: UploadFile = File(...), change_db: float = Form(...)):
-     input_path = await save_upload_file(file, "volume_in_")
-     background_tasks.add_task(cleanup_file, input_path); output_path = None
+ async def change_volume(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to adjust volume for."),
+     change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
+ ):
+     """Adjusts the volume of an audio file by a specified decibel amount using Pydub."""
+     logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
+     input_path = await save_upload_file(file, prefix="volume_in_")
+     background_tasks.add_task(cleanup_file, input_path)
+     output_path = None
    try:
        audio = load_audio_pydub(input_path)
-         adjusted = audio + change_db
-         fmt = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
-         output_path = export_audio_pydub(adjusted, fmt)
+         # Check for potential silence before applying gain
+         if audio.dBFS == -float('inf'):
+             logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
+         adjusted_audio = audio + change_db
+         logger.info(f"Volume adjusted by {change_db}dB.")
+
+         original_format = os.path.splitext(file.filename)[1][1:].lower()
+         if not original_format or len(original_format) > 5: original_format = "mp3"
+
+         output_path = export_audio_pydub(adjusted_audio, original_format)
        background_tasks.add_task(cleanup_file, output_path)
-         fname = f"volume_{change_db}dB_{os.path.splitext(file.filename)[0]}.{fmt}"
-         return FileResponse(output_path, media_type=f"audio/{fmt}", filename=fname)
+
+         # Create filename
+         sign = "+" if change_db >= 0 else ""
+         output_filename=f"volume_{sign}{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{original_format}",
+             filename=output_filename
+         )
+     except HTTPException as http_exc:
+         logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
+         if output_path: cleanup_file(output_path)
+         raise http_exc
    except Exception as e:
+         logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
-         if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Volume error: {e}")
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}")
+

@app.post("/convert", tags=["Basic Editing"])
- async def convert_format( background_tasks: BackgroundTasks, file: UploadFile = File(...), output_format: str = Form(...)):
-     allowed = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus'}
-     if output_format.lower() not in allowed: raise HTTPException(422, f"Invalid format. Allowed: {allowed}")
-     input_path = await save_upload_file(file, "convert_in_")
-     background_tasks.add_task(cleanup_file, input_path); output_path = None
+ async def convert_format(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to convert."),
+     output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
+ ):
+     """Converts an audio file to a different format using Pydub."""
+     # Define allowed formats explicitly
+     allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'} # Expanded list
+     output_format_lower = output_format.lower()
+     if output_format_lower not in allowed_formats:
+         raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(list(allowed_formats)))}")
+
+     logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
+     input_path = await save_upload_file(file, prefix="convert_in_")
+     background_tasks.add_task(cleanup_file, input_path)
+     output_path = None
    try:
+         # Load using pydub, which handles many input formats
        audio = load_audio_pydub(input_path)
-         output_path = export_audio_pydub(audio, output_format.lower())
+         logger.info(f"Successfully loaded '{file.filename}' for conversion.")
+
+         # Export using pydub
+         output_path = export_audio_pydub(audio, output_format_lower)
        background_tasks.add_task(cleanup_file, output_path)
-         fname = f"{os.path.splitext(file.filename)[0]}_converted.{output_format.lower()}"
-         return FileResponse(output_path, media_type=f"audio/{output_format.lower()}", filename=fname)
+         logger.info(f"Successfully exported to {output_format_lower}")
+
+         # Construct new filename
+         filename_base = os.path.splitext(file.filename)[0]
+         output_filename = f"{filename_base}_converted.{output_format_lower}"
+
+         # Determine media type (MIME type) - might need refinement for less common types
+         media_type_map = {
+             'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
+             'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4', # m4a often uses mp4 container
+             'opus': 'audio/opus', 'wma':'audio/x-ms-wma', 'aiff':'audio/aiff'
+         }
+         media_type = media_type_map.get(output_format_lower, 'application/octet-stream') # Default binary if unknown
+
+         return FileResponse(
+             path=output_path,
+             media_type=media_type,
+             filename=output_filename
+         )
+     except HTTPException as http_exc:
+         logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
+         if output_path: cleanup_file(output_path)
+         raise http_exc
    except Exception as e:
+         logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
-         if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Convert error: {e}")
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}")


- # --- AI Endpoints (Unchanged Functionality, relies on successful loading) ---

@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
-     model_key: str = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use."),
    output_format: str = Form("wav", description="Output format (wav, flac recommended).")
):
    """Enhances speech audio using a pre-loaded SpeechBrain model."""
-     if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
-     if model_key not in enhancement_models:
-         logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
-         raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server startup logs.")

-     loaded_model = enhancement_models[model_key]
-     logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Enhancement task completed.")
        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
        background_tasks.add_task(cleanup_file, output_path)
        output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
-         return FileResponse(path=output_path, media_type=f"audio/{output_format}", filename=output_filename)
    except Exception as e:
-         logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
-         if isinstance(e, HTTPException): raise e; else: raise HTTPException(500, f"Enhancement error: {e}")


@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
-     model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
-     if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
-     if model_key not in separation_models:
-         logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
-         raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server startup logs.")

-     loaded_model = separation_models[model_key]
    valid_stems = set(loaded_model.sources)
    requested_stems = set(s.lower() for s in stems)
-     if not requested_stems.issubset(valid_stems):
-         raise HTTPException(422, f"Invalid stem(s). Model '{model_key}' provides: {valid_stems}")

-     logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)
-     stem_output_paths: Dict[str, str] = {}
-     zip_buffer = io.BytesIO(); zipf = None # Define zipf here

    try:
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr
        )
-         logger.info("Separation task completed.")

        zipf = zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED)
        for stem_name in requested_stems:
            if stem_name in all_separated_stems_tensors:
                stem_tensor = all_separated_stems_tensors[stem_name]
-                 stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
-                 stem_output_paths[stem_name] = stem_path
-                 background_tasks.add_task(cleanup_file, stem_path) # Cleanup after response sent
-                 archive_name = f"{stem_name}.{output_format}"
-                 zipf.write(stem_path, arcname=archive_name)
-                 logger.info(f"Added '{archive_name}' to ZIP.")

        zipf.close() # Close zip file BEFORE seeking/reading
-         zip_buffer.seek(0)

-         zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
        return StreamingResponse(
-             zip_buffer,
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
        )
-     except Exception as e:
-         logger.error(f"Error during separation operation: {e}", exc_info=True)
-         # Ensure buffer/zipfile are closed and temp files cleaned up on error
        if zipf: zipf.close() # Ensure zipfile is closed
        if zip_buffer: zip_buffer.close()
        for path in stem_output_paths.values(): cleanup_file(path)
-         if isinstance(e, HTTPException): raise e
-         else: raise HTTPException(500, f"Separation error: {e}")
-
-     # --- (How to Run instructions remain the same) ---
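For reference, a minimal client sketch for the reworked basic-editing endpoints, assuming the app runs locally on port 8000 and the requests package is installed (host, port, and file names are illustrative):

    import requests

    with open("input.mp3", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/trim",
            files={"file": ("input.mp3", f, "audio/mpeg")},
            data={"start_ms": 1000, "end_ms": 5000},
        )
    resp.raise_for_status()
    with open("trimmed.mp3", "wb") as out:
        out.write(resp.content)  # FileResponse body is the trimmed audio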
  ch = logging.StreamHandler()
26
  ch.setLevel(logging.INFO)
27
  ch.setFormatter(formatter)
28
+ # Avoid adding handler multiple times if script reloads
29
+ if not logger_init.handlers:
30
+ logger_init.addHandler(ch)
31
 
32
+ AI_LIBS_AVAILABLE = False
33
  try:
34
  logger_init.info("Importing torch...")
35
  import torch
 
48
  AI_LIBS_AVAILABLE = True
49
  except ImportError as e:
50
  logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
51
+ logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
52
  logger_init.error("AI features will be unavailable.")
53
+ # Define placeholders so the rest of the code doesn't break completely on import error
54
  torch = None
55
  sf = None
56
  np = None
57
  librosa = None
58
  speechbrain = None
59
  demucs = None
 
60
 
61
  # --- Configuration & Setup ---
62
  TEMP_DIR = tempfile.gettempdir()
63
+ # Attempt to create temp dir if it doesn't exist (useful in some environments)
64
+ try:
65
+ os.makedirs(TEMP_DIR, exist_ok=True)
66
+ except OSError as e:
67
+ logger_init.error(f"Could not create temporary directory {TEMP_DIR}: {e}")
68
+ # Fallback or raise an error depending on desired behavior
69
+ TEMP_DIR = "." # Use current directory as fallback (less ideal)
70
+ logger_init.warning(f"Using current directory '{TEMP_DIR}' for temporary files.")
71
+
72
 
73
  # Configure main app logging (use the root logger setup by FastAPI/Uvicorn)
74
+ # This logger will be used by endpoint handlers
75
+ logger = logging.getLogger(__name__)
76
 
77
  # --- Global Variables for Loaded Models ---
78
  ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
79
+ # Choose a default Demucs model (htdemucs is good quality)
80
+ SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
81
 
82
  enhancement_models: Dict[str, Any] = {}
83
  separation_models: Dict[str, Any] = {}
84
 
85
+ # Target sampling rates (confirm from model specifics if necessary)
86
+ ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
87
+ DEMUCS_SR = 44100 # Demucs default is 44.1kHz
88
 
89
  # --- Device Selection ---
90
+ if AI_LIBS_AVAILABLE and torch:
91
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
92
+ logger_init.info(f"Selected device for AI models: {DEVICE}")
93
  else:
94
+ DEVICE = "cpu" # Fallback if torch failed import
95
+ logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")
96
 
97
 
98
+ # --- Helper Functions ---
99
+
100
  def cleanup_file(file_path: str):
101
  """Safely remove a file."""
102
  try:
103
+ if file_path and isinstance(file_path, str) and os.path.exists(file_path):
104
  os.remove(file_path)
105
  # logger.info(f"Cleaned up temporary file: {file_path}") # Reduce log noise
106
  except Exception as e:
107
+ # Log error but don't crash the cleanup process for other files
108
  logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
109
 
110
  async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
111
  """Saves an uploaded file to a temporary location and returns the path."""
112
+ if not upload_file or not upload_file.filename:
113
+ raise HTTPException(status_code=400, detail="Invalid file upload object.")
114
+
115
  _, file_extension = os.path.splitext(upload_file.filename)
116
+ # Default to .wav if no extension, as it's widely compatible for loading
117
  if not file_extension: file_extension = ".wav"
118
  temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
119
+
120
  try:
121
+ logger.debug(f"Attempting to save uploaded file to: {temp_file_path}")
122
  with open(temp_file_path, "wb") as buffer:
123
+ # Read chunk by chunk for large files
124
+ while content := await upload_file.read(1024 * 1024): # 1MB chunks
125
+ buffer.write(content)
126
+ logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {temp_file_path}")
127
  return temp_file_path
128
  except Exception as e:
129
+ logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {temp_file_path}: {e}", exc_info=True)
130
+ cleanup_file(temp_file_path) # Attempt cleanup if saving failed
131
  raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
132
  finally:
133
+ # Ensure file is closed even if saving fails mid-way
134
+ try:
135
+ await upload_file.close()
136
+ except Exception:
137
+ pass # Ignore errors during close if already failed
138
+
139
+
140
+ # --- Audio Loading/Saving Functions ---
141
 
 
142
  def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
143
+ """Loads audio using soundfile, converts to mono float32 Torch tensor, optionally resamples."""
144
+ if not AI_LIBS_AVAILABLE:
145
+ raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
146
+ if not os.path.exists(file_path):
147
+ raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")
148
+
149
  try:
150
  audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
151
+ logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")
152
+
153
+ # Ensure mono
154
+ if audio.ndim > 1:
155
+ # Check which dimension is smaller (likely channels)
156
+ channel_dim = np.argmin(audio.shape)
157
+ if audio.shape[channel_dim] > 1 and audio.shape[channel_dim] < 10: # Heuristic: <10 channels
158
+ logger.info(f"Detected {audio.shape[channel_dim]} channels. Converting to mono by averaging axis {channel_dim}.")
159
+ audio = np.mean(audio, axis=channel_dim)
160
+ else: # Fallback or if shape is ambiguous (e.g., very short stereo)
161
+ logger.warning(f"Audio has shape {audio.shape}. Taking first channel/element assuming mono or channel-first.")
162
+ audio = audio[0] if channel_dim == 0 else audio[:, 0] # Select first index of the likely channel dimension
163
+
164
+ logger.debug(f"Shape after mono conversion: {audio.shape}")
165
+
166
 
167
+ # Ensure it's now 1D
168
+ audio = audio.flatten()
169
+
170
+ # Convert numpy array to torch tensor
171
  audio_tensor = torch.from_numpy(audio).float()
172
+
173
+ # Resample if necessary using librosa
174
+ current_sr = orig_sr
175
  if target_sr and orig_sr != target_sr:
176
+ if librosa is None: raise RuntimeError("Librosa missing for resampling")
177
  logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
178
+ # Librosa works on numpy
179
  audio_np = audio_tensor.numpy()
180
+ resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best') # Specify resampling type
181
  audio_tensor = torch.from_numpy(resampled_audio_np).float()
182
  current_sr = target_sr
183
+ logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
184
+
185
+ # Ensure tensor is on the correct device
186
  return audio_tensor.to(DEVICE), current_sr
187
+
188
+ except sf.SoundFileError as sf_err:
189
+ logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
190
+ cleanup_file(file_path)
191
+ raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}")
192
  except Exception as e:
193
+ logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
194
  cleanup_file(file_path)
195
+ raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.")
196
 
197
 
198
  def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
199
  """Saves audio data (Tensor or NumPy array) to a temporary file."""
200
+ if not AI_LIBS_AVAILABLE:
201
+ raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
202
+
203
  output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
204
  output_path = os.path.join(TEMP_DIR, output_filename)
205
  try:
206
+ logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
207
+
208
+ # Convert tensor to numpy array if needed
209
  if isinstance(audio_data, torch.Tensor):
210
+ logger.debug("Converting output tensor to NumPy array.")
211
+ # Ensure tensor is on CPU before converting to numpy
212
  audio_np = audio_data.detach().cpu().numpy()
213
  elif isinstance(audio_data, np.ndarray):
214
  audio_np = audio_data
215
  else:
216
+ raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
217
+
218
+ # Ensure data is float32
219
+ if audio_np.dtype != np.float32:
220
+ logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
221
+ audio_np = audio_np.astype(np.float32)
222
 
223
+ # Clip values to avoid potential issues with formats expecting [-1, 1]
224
  audio_np = np.clip(audio_np, -1.0, 1.0)
225
 
226
+ # Ensure audio is 1D (mono) before saving with soundfile or pydub conversion
227
+ if audio_np.ndim > 1:
228
+ logger.warning(f"Output audio data has {audio_np.ndim} dimensions, attempting to flatten or take first dimension.")
229
+ # Try averaging channels if shape suggests stereo/multi-channel
230
+ channel_dim = np.argmin(audio_np.shape)
231
+ if audio_np.shape[channel_dim] > 1 and audio_np.shape[channel_dim] < 10:
232
+ audio_np = np.mean(audio_np, axis=channel_dim)
233
+ else: # Otherwise just flatten
234
+ audio_np = audio_np.flatten()
235
+
236
+
237
+ # Use soundfile (preferred for wav/flac)
238
  if output_format.lower() in ['wav', 'flac']:
239
  sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
240
  else:
241
+ # For lossy formats, use pydub
242
+ logger.debug(f"Using pydub to export to lossy format: {output_format}")
243
+ # Scale float32 [-1, 1] to int16 for pydub
244
+ audio_int16 = (audio_np * 32767).astype(np.int16)
245
+ segment = AudioSegment(
246
+ audio_int16.tobytes(),
247
+ frame_rate=sampling_rate,
248
+ sample_width=audio_int16.dtype.itemsize,
249
+ channels=1 # Assuming mono after processing above
250
+ )
251
+ # Pydub might need explicit ffmpeg path in some envs
252
+ # AudioSegment.converter = "/path/to/ffmpeg" # Uncomment and set path if needed
253
  segment.export(output_path, format=output_format)
254
+
255
+ logger.info(f"Successfully saved AI audio to {output_path}")
256
  return output_path
257
  except Exception as e:
258
  logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
259
+ cleanup_file(output_path) # Attempt cleanup on saving failure
260
+ raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.")
261
 
262
+
263
+ # --- Pydub Loading/Exporting (for basic edits) ---
264
  def load_audio_pydub(file_path: str) -> AudioSegment:
265
+ """Loads an audio file using pydub."""
266
+ if not os.path.exists(file_path):
267
+ raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")
268
  try:
269
+ logger.debug(f"Loading audio with pydub: {file_path}")
270
+ # Explicitly provide format if possible, helps pydub sometimes
271
+ file_ext = os.path.splitext(file_path)[1][1:].lower()
272
+ if file_ext:
273
+ audio = AudioSegment.from_file(file_path, format=file_ext)
274
+ else:
275
+ audio = AudioSegment.from_file(file_path) # Let pydub detect
276
+ logger.info(f"Loaded audio using pydub from: {file_path}")
277
  return audio
278
+ except CouldntDecodeError as e:
279
+ logger.warning(f"Pydub CouldntDecodeError for {file_path}: {e}")
280
+ cleanup_file(file_path)
281
+ raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}")
282
+ except Exception as e:
283
+ logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
284
+ cleanup_file(file_path)
285
+ raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}")
286
 
287
  def export_audio_pydub(audio: AudioSegment, format: str) -> str:
288
+ """Exports a Pydub AudioSegment to a temporary file and returns the path."""
289
  output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
290
  output_path = os.path.join(TEMP_DIR, output_filename)
291
  try:
292
+ logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
293
  audio.export(output_path, format=format.lower())
294
  return output_path
295
+ except Exception as e:
296
+ logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
297
+ cleanup_file(output_path) # Cleanup if export failed
298
+ raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
299
+
300
 
301
+ # --- Synchronous AI Inference Functions ---
302
 
 
303
  def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
304
+ """Synchronous wrapper for SpeechBrain enhancement model inference."""
305
+ if not AI_LIBS_AVAILABLE or not model: raise ValueError("Enhancement model/libs not available")
306
  try:
307
  logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
308
+ model_device = next(model.parameters()).device # Check model's current device
309
  if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
310
+ # Add batch dimension if model expects it (most do)
311
  if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0)
312
+
313
  with torch.no_grad():
314
+ # Check if model expects lengths parameter
315
+ enhance_method = getattr(model, "enhance_batch", getattr(model, "forward", None))
316
+ if "lengths" in enhance_method.__code__.co_varnames:
317
+ enhanced_tensor = enhance_method(audio_tensor, lengths=torch.tensor([audio_tensor.shape[-1]]).to(model_device))
318
+ else:
319
+ enhanced_tensor = enhance_method(audio_tensor)
320
+
321
+
322
+ # Remove batch dimension from output before returning, move back to CPU
323
  enhanced_audio = enhanced_tensor.squeeze(0).cpu()
324
  logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
325
  return enhanced_audio
326
+ except Exception as e:
327
+ logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
328
+ raise # Re-raise to be caught by the async wrapper
329
 
330
  def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
331
+ """Synchronous wrapper for Demucs source separation model inference."""
332
+ if not AI_LIBS_AVAILABLE or not model: raise ValueError("Separation model/libs not available")
333
  if not demucs: raise RuntimeError("Demucs library missing")
334
  try:
335
  logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
336
  model_device = next(model.parameters()).device
337
  if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
338
+
339
+ # Demucs expects audio as (batch, channels, samples)
340
+ if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
341
+ elif audio_tensor.ndim == 2: audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
342
+
343
+ # Repeat channel if model expects stereo but input is mono
344
  if audio_tensor.shape[1] != model.audio_channels:
345
+ if audio_tensor.shape[1] == 1:
346
+ logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
347
+ audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
348
+ else:
349
+ raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")
350
+
351
+ logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
352
+
353
  with torch.no_grad():
354
+ # Use demucs.apply.apply_model for handling chunking etc.
355
+ # Requires input shape (channels, samples) - process first batch item
356
  audio_to_process = audio_tensor.squeeze(0)
357
+ # Note: shifts=1, split=True are common defaults for quality
358
+ out = demucs.apply.apply_model(model, audio_to_process, device=model_device, shifts=1, split=True, overlap=0.25, progress=False) # Disable progress bar in logs
359
+ # Output shape (stems, channels, samples)
360
+
361
+ logger.debug(f"Raw separated sources tensor shape: {out.shape}")
362
+
363
+ # Map stems based on the model's sources list
364
  stem_map = {name: out[i] for i, name in enumerate(model.sources)}
365
+
366
+ # Convert back to mono for simplicity (average channels) and move to CPU
367
+ output_stems = {}
368
+ for name, data in stem_map.items():
369
+ # Average channels, detach, move to CPU
370
+ output_stems[name] = data.mean(dim=0).detach().cpu()
371
+
372
+ logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
373
  return output_stems
374
+
375
+ except Exception as e:
376
+ logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
377
+ raise
378
 
379
 
380
  # --- Model Loading Function (Enhanced Logging) ---
 
382
  """Loads AI models at startup using correct libraries."""
383
  logger_load = logging.getLogger("ModelLoader") # Use specific logger
384
  logger_load.setLevel(logging.INFO)
385
+ # Ensure handler is attached if logger is newly created
386
+ if not logger_load.handlers and ch: logger_load.addHandler(ch)
387
 
388
  global enhancement_models, separation_models
389
  if not AI_LIBS_AVAILABLE:
390
  logger_load.error("Core AI libraries not available. Cannot load AI models.")
391
  return
392
 
393
+ load_success_flags = {"enhancement": False, "separation": False}
394
+
395
  # --- Load Enhancement Model ---
396
  enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
397
  logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
398
  try:
 
399
  logger_load.info(f"Attempting load on device: {DEVICE}")
400
+ # Consider adding savedir if cache issues arise in HF Spaces
401
+ # savedir_sb = os.path.join(TEMP_DIR, "speechbrain_models")
402
+ # os.makedirs(savedir_sb, exist_ok=True)
403
  enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
404
  source=enhancement_model_hparams,
405
+ # savedir=savedir_sb,
406
  run_opts={"device": DEVICE}
407
  )
 
408
  model_device = next(enhancer.parameters()).device
409
  enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
410
  logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
411
+ load_success_flags["enhancement"] = True
412
  except Exception as e:
413
+ logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
414
  logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
415
  logger_load.warning("Enhancement features will be unavailable.")
416
 
 
420
  logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
421
  try:
422
  logger_load.info(f"Attempting load on device: {DEVICE}")
423
+ # This automatically handles downloading the model checkpoint via demucs package
424
  separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
425
  model_device = next(separator.parameters()).device
426
  separation_models[SEPARATION_MODEL_KEY] = separator
427
  logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
428
+ logger_load.info(f"Separation model available sources: {separator.sources}")
429
+ load_success_flags["separation"] = True
430
  except Exception as e:
431
  logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
432
  logger_load.error(f"Traceback: {traceback.format_exc()}")
433
+ logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
434
  logger_load.warning("Separation features will be unavailable.")
435
 
436
  logger_load.info(f"--- Model loading attempts finished ---")
437
+ logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
438
+ logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")
439
 
440
 
441
  # --- FastAPI App ---
442
  app = FastAPI(
443
  title="AI Audio Editor API",
444
  description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
445
+ version="2.1.2", # Incremented version
446
  )
447
 
448
  @app.on_event("startup")
 
450
  # Use the init logger for startup messages
451
  logger_init.info("--- FastAPI Application Startup ---")
452
  if AI_LIBS_AVAILABLE:
453
+ logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
454
  # Run blocking model load in thread
455
  await asyncio.to_thread(load_hf_models)
456
+ logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
457
  else:
458
+ logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
459
+ logger_init.info("--- Startup sequence complete ---")
460
 
461
  # --- API Endpoints ---
462
 
463
  @app.get("/", tags=["General"])
464
  def read_root():
465
+ """Root endpoint providing a welcome message and status of loaded models."""
466
  features = ["/trim", "/concat", "/volume", "/convert"]
467
+ ai_features_status = {}
468
+
469
+ if AI_LIBS_AVAILABLE:
470
+ if enhancement_models:
471
+ ai_features_status[ENHANCEMENT_MODEL_KEY] = "Loaded"
472
+ else:
473
+ ai_features_status[ENHANCEMENT_MODEL_KEY] = "Failed to load (check startup logs)"
474
+
475
+ if separation_models:
476
+ model = separation_models.get(SEPARATION_MODEL_KEY)
477
+ sources_str = ', '.join(model.sources) if model else 'N/A'
478
+ ai_features_status[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {sources_str})"
479
+ else:
480
+ ai_features_status[SEPARATION_MODEL_KEY] = "Failed to load (check startup logs)"
481
+ else:
482
+ ai_features_status["AI Status"] = "Libraries Failed Import"
483
+
484
 
485
  return {
486
  "message": "Welcome to the AI Audio Editor API.",
487
  "status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
488
+ "ai_models_status": ai_features_status,
489
+ "basic_endpoints": features,
490
+ "notes": "Requires FFmpeg. AI features require successful model loading at startup."
 
 
491
  }
492
 
493
 
494
  # --- Basic Editing Endpoints ---
495
+
496
  @app.post("/trim", tags=["Basic Editing"])
497
+ async def trim_audio(
498
+ background_tasks: BackgroundTasks,
499
+ file: UploadFile = File(..., description="Audio file to trim."),
500
+ start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
501
+ end_ms: int = Form(..., gt=0, description="End time in milliseconds.") # Ensure end > 0
502
+ ):
503
+ """Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub."""
504
+ if end_ms <= start_ms:
505
+ raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")
506
+
507
+ logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
508
+ input_path = await save_upload_file(file, prefix="trim_in_")
509
+ # Schedule cleanup immediately after saving, even if loading fails later
510
+ background_tasks.add_task(cleanup_file, input_path)
511
+ output_path = None # Define before try block
512
+
513
  try:
514
+ audio = load_audio_pydub(input_path) # Can raise HTTPException
515
  trimmed_audio = audio[start_ms:end_ms]
516
+ logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
517
+
518
+ # Determine original format for export
519
+ original_format = os.path.splitext(file.filename)[1][1:].lower()
520
+ # Use mp3 as default only if no extension or if it's 'tmp' etc.
521
+ if not original_format or len(original_format) > 5: # Basic check for valid extension length
522
+ original_format = "mp3"
523
+ logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")
524
+
525
+ output_path = export_audio_pydub(trimmed_audio, original_format) # Can raise HTTPException
526
+ background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
527
+
528
+ # Create a more informative filename
529
+ output_filename=f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"
530
+
531
+ return FileResponse(
532
+ path=output_path,
533
+ media_type=f"audio/{original_format}", # Best guess for media type
534
+ filename=output_filename
535
+ )
536
+ except HTTPException as http_exc:
537
+ # If load/export raised HTTPException, re-raise it
538
+ # Cleanup might have already been scheduled, background tasks handle errors
539
+ logger.error(f"HTTP Exception during trim: {http_exc.detail}")
540
+ if output_path: cleanup_file(output_path) # Try immediate cleanup if output exists
541
+ raise http_exc
542
  except Exception as e:
543
+ # Catch other unexpected errors during trimming logic
544
+ logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
545
  if output_path: cleanup_file(output_path)
546
+ raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}")
547
+
548
 
549
  @app.post("/concat", tags=["Basic Editing"])
550
+ async def concatenate_audio(
551
+ background_tasks: BackgroundTasks,
552
+ files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
553
+ output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
554
+ ):
555
+ """Concatenates two or more audio files sequentially using Pydub."""
556
+ if len(files) < 2:
557
+ raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
558
+
559
+ logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
560
+ input_paths = [] # Keep track of all saved input file paths
561
+ output_path = None # Define before try block
562
+
563
  try:
564
+ combined_audio: Optional[AudioSegment] = None
565
+ for i, file in enumerate(files):
566
+ if not file or not file.filename:
567
+ logger.warning(f"Skipping invalid file upload at index {i}.")
568
+ continue # Skip potentially empty file entries
569
+
570
+ input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
571
+ input_paths.append(input_path)
572
+ # Schedule cleanup for this specific input file immediately
573
+ background_tasks.add_task(cleanup_file, input_path)
574
+
575
+ try:
576
+ audio = load_audio_pydub(input_path)
577
+ if combined_audio is None:
578
+ combined_audio = audio
579
+ logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
580
+ else:
581
+ logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
582
+ combined_audio += audio
583
+ except HTTPException as load_exc:
584
+ # Log error but continue trying to load other files if possible
585
+ logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
586
+ except Exception as load_exc:
587
+ logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)
588
+
589
+
590
+ if combined_audio is None:
591
+ raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")
592
+
593
+ logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")
594
+
595
+ output_path = export_audio_pydub(combined_audio, output_format) # Can raise HTTPException
596
+ background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
597
+
598
+ # Determine a reasonable output filename
599
+ first_valid_filename = files[0].filename if files and files[0] else "audio"
600
+ first_filename_base = os.path.splitext(first_valid_filename)[0]
601
+ output_filename = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
602
+
603
+ return FileResponse(
604
+ path=output_path,
605
+ media_type=f"audio/{output_format}",
606
+ filename=output_filename
607
+ )
608
+ except HTTPException as http_exc:
609
+ # If load/export raised HTTPException, re-raise it
610
+ logger.error(f"HTTP Exception during concat: {http_exc.detail}")
611
+ # Cleanup for output path, inputs are handled by background tasks
612
+ if output_path: cleanup_file(output_path)
613
+ raise http_exc
614
  except Exception as e:
615
+ # Catch other unexpected errors during combining logic
616
+ logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
617
  if output_path: cleanup_file(output_path)
618
+ raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}")
619
+
620
 
621
  @app.post("/volume", tags=["Basic Editing"])
+ async def change_volume(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to adjust volume for."),
+     change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
+ ):
+     """Adjusts the volume of an audio file by a specified decibel amount using Pydub."""
+     logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
+     input_path = await save_upload_file(file, prefix="volume_in_")
+     background_tasks.add_task(cleanup_file, input_path)
+     output_path = None
      try:
          audio = load_audio_pydub(input_path)
+         # Check for potential silence before applying gain
+         if audio.dBFS == -float('inf'):
+             logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
+         adjusted_audio = audio + change_db
+         logger.info(f"Volume adjusted by {change_db}dB.")
+
+         original_format = os.path.splitext(file.filename)[1][1:].lower()
+         if not original_format or len(original_format) > 5: original_format = "mp3"
+
+         output_path = export_audio_pydub(adjusted_audio, original_format)
          background_tasks.add_task(cleanup_file, output_path)
+
+         # Create filename
+         sign = "+" if change_db >= 0 else ""
+         output_filename = f"volume_{sign}{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{original_format}",
+             filename=output_filename
+         )
+     except HTTPException as http_exc:
+         logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
+         if output_path: cleanup_file(output_path)
+         raise http_exc
      except Exception as e:
+         logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
          if output_path: cleanup_file(output_path)
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}")
+
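+ # Example client call (illustrative sketch, not part of the app): assuming the
+ # server runs locally on port 7860, /volume can be exercised with `requests`:
+ #
+ #   import requests
+ #
+ #   with open("speech.wav", "rb") as f:
+ #       resp = requests.post(
+ #           "http://localhost:7860/volume",
+ #           files={"file": ("speech.wav", f, "audio/wav")},
+ #           data={"change_db": "-6.0"},  # attenuate by 6 dB
+ #       )
+ #   resp.raise_for_status()
+ #   with open("speech_-6dB.wav", "wb") as out:
+ #       out.write(resp.content)
+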
  @app.post("/convert", tags=["Basic Editing"])
+ async def convert_format(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to convert."),
+     output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
+ ):
+     """Converts an audio file to a different format using Pydub."""
+     # Define allowed formats explicitly
+     allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'}  # Expanded list
+     output_format_lower = output_format.lower()
+     if output_format_lower not in allowed_formats:
+         raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(allowed_formats))}")
+
+     logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
+     input_path = await save_upload_file(file, prefix="convert_in_")
+     background_tasks.add_task(cleanup_file, input_path)
+     output_path = None
      try:
+         # Load using pydub, which handles many input formats
          audio = load_audio_pydub(input_path)
+         logger.info(f"Successfully loaded '{file.filename}' for conversion.")
+
+         # Export using pydub
+         output_path = export_audio_pydub(audio, output_format_lower)
          background_tasks.add_task(cleanup_file, output_path)
+         logger.info(f"Successfully exported to {output_format_lower}")
+
+         # Construct new filename
+         filename_base = os.path.splitext(file.filename)[0]
+         output_filename = f"{filename_base}_converted.{output_format_lower}"
+
+         # Determine media type (MIME type) - might need refinement for less common types
+         media_type_map = {
+             'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
+             'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4',  # m4a often uses the mp4 container
+             'opus': 'audio/opus', 'wma': 'audio/x-ms-wma', 'aiff': 'audio/aiff'
+         }
+         media_type = media_type_map.get(output_format_lower, 'application/octet-stream')  # Default binary if unknown
+
+         return FileResponse(
+             path=output_path,
+             media_type=media_type,
+             filename=output_filename
+         )
+     except HTTPException as http_exc:
+         logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
+         if output_path: cleanup_file(output_path)
+         raise http_exc
      except Exception as e:
+         logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
          if output_path: cleanup_file(output_path)
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}")
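+
+ # Example client call (illustrative sketch, not part of the app): converting an
+ # MP3 to FLAC via /convert, assuming a local server on port 7860:
+ #
+ #   import requests
+ #
+ #   with open("input.mp3", "rb") as f:
+ #       resp = requests.post(
+ #           "http://localhost:7860/convert",
+ #           files={"file": ("input.mp3", f, "audio/mpeg")},
+ #           data={"output_format": "flac"},
+ #       )
+ #   resp.raise_for_status()
+ #   with open("input_converted.flac", "wb") as out:
+ #       out.write(resp.content)
+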
+ # --- AI Endpoints ---
 
  @app.post("/enhance", tags=["AI Editing"])
  async def enhance_speech(
      background_tasks: BackgroundTasks,
      file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
+     # Keep model_key optional for now; it falls back to the default when only one model is loaded
+     model_key: Optional[str] = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use (defaults to primary)."),
      output_format: str = Form("wav", description="Output format (wav, flac recommended).")
  ):
      """Enhances speech audio using a pre-loaded SpeechBrain model."""
+     if not AI_LIBS_AVAILABLE: raise HTTPException(status_code=501, detail="AI processing libraries not available.")
+     # Use the provided key or the default
+     actual_model_key = model_key or ENHANCEMENT_MODEL_KEY
+     if actual_model_key not in enhancement_models:
+         logger.error(f"Enhancement model key '{actual_model_key}' requested but model not loaded.")
+         raise HTTPException(status_code=503, detail=f"Enhancement model '{actual_model_key}' is not loaded or available. Check server startup logs.")
+
+     loaded_model = enhancement_models[actual_model_key]
 
+     logger.info(f"Enhance request: file='{file.filename}', model='{actual_model_key}', format='{output_format}'")
      input_path = await save_upload_file(file, prefix="enhance_in_")
      background_tasks.add_task(cleanup_file, input_path)
      output_path = None
      try:
+         # Load audio as a tensor at the model's expected sampling rate (16 kHz)
          audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
+
          logger.info("Submitting enhancement task to background thread...")
          enhanced_audio_tensor = await asyncio.to_thread(
              _run_enhancement_sync, loaded_model, audio_tensor, current_sr
          )
          logger.info("Enhancement task completed.")
+
+         # Save the result (tensor output from the enhancer at 16 kHz)
          output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
          background_tasks.add_task(cleanup_file, output_path)
+
          output_filename = f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
+         media_type = f"audio/{output_format}"  # Basic media type
+         return FileResponse(path=output_path, media_type=media_type, filename=output_filename)
+
+     except HTTPException as http_exc:
+         logger.error(f"HTTP Exception during enhancement: {http_exc.detail}")
+         if output_path: cleanup_file(output_path)
+         raise http_exc
      except Exception as e:
+         logger.error(f"Unexpected error during enhancement operation: {e}", exc_info=True)
          if output_path: cleanup_file(output_path)
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during enhancement: {str(e)}")
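+
+ # Example client call (illustrative sketch, not part of the app): assuming a local
+ # server on port 7860 and the default enhancement model loaded at startup:
+ #
+ #   import requests
+ #
+ #   with open("noisy_speech.wav", "rb") as f:
+ #       resp = requests.post(
+ #           "http://localhost:7860/enhance",
+ #           files={"file": ("noisy_speech.wav", f, "audio/wav")},
+ #           data={"output_format": "wav"},
+ #           timeout=600,  # enhancement can take minutes on CPU
+ #       )
+ #   resp.raise_for_status()
+ #   with open("enhanced_noisy_speech.wav", "wb") as out:
+ #       out.write(resp.content)
+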
  @app.post("/separate", tags=["AI Editing"])
  async def separate_sources(
      background_tasks: BackgroundTasks,
      file: UploadFile = File(..., description="Music audio file to separate into stems."),
+     model_key: Optional[str] = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use (defaults to primary)."),
      stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
      output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
  ):
      """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
+     if not AI_LIBS_AVAILABLE: raise HTTPException(status_code=501, detail="AI processing libraries not available.")
+     actual_model_key = model_key or SEPARATION_MODEL_KEY
+     if actual_model_key not in separation_models:
+         logger.error(f"Separation model key '{actual_model_key}' requested but model not loaded.")
+         raise HTTPException(status_code=503, detail=f"Separation model '{actual_model_key}' is not loaded or available. Check server startup logs.")
 
+     loaded_model = separation_models[actual_model_key]
      valid_stems = set(loaded_model.sources)
      requested_stems = set(s.lower() for s in stems)
 
+     # Reject an empty stem list outright
+     if not requested_stems:
+         raise HTTPException(status_code=422, detail="No stems requested for separation.")
+     # Check that *all* requested stems are valid for this model
+     invalid_stems = requested_stems - valid_stems
+     if invalid_stems:
+         raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested: {', '.join(invalid_stems)}. Model '{actual_model_key}' provides: {', '.join(valid_stems)}")
+
+     logger.info(f"Separate request: file='{file.filename}', model='{actual_model_key}', stems={requested_stems}, format='{output_format}'")
      input_path = await save_upload_file(file, prefix="separate_in_")
      background_tasks.add_task(cleanup_file, input_path)
+     stem_output_paths: Dict[str, str] = {}  # Store paths of successfully saved stems
+     zip_buffer = io.BytesIO()  # In-memory buffer for the ZIP archive
+     zipf = None  # ZipFile object, created inside the try block
 
      try:
+         # Load audio as a tensor at the model's expected sampling rate (Demucs default: 44.1 kHz)
          audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
+
          logger.info("Submitting separation task to background thread...")
          all_separated_stems_tensors = await asyncio.to_thread(
              _run_separation_sync, loaded_model, audio_tensor, current_sr
          )
+         logger.info("Separation task completed successfully.")
 
+         # --- Create ZIP file in memory ---
+         logger.info("Creating ZIP archive in memory...")
          zipf = zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED)
+         files_added_to_zip = 0
          for stem_name in requested_stems:
              if stem_name in all_separated_stems_tensors:
                  stem_tensor = all_separated_stems_tensors[stem_name]
+                 stem_path = None  # Define stem_path before the inner try
+                 try:
+                     # Save the stem temporarily (save_hf_audio handles tensors),
+                     # using the model's native sampling rate (DEMUCS_SR) for output
+                     stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
+                     stem_output_paths[stem_name] = stem_path
+                     # Schedule cleanup AFTER the zip has been sent
+                     background_tasks.add_task(cleanup_file, stem_path)
+
+                     # Use a simpler archive name within the zip
+                     archive_name = f"{stem_name}.{output_format}"
+                     zipf.write(stem_path, arcname=archive_name)
+                     files_added_to_zip += 1
+                     logger.info(f"Added '{archive_name}' to ZIP.")
+                 except Exception as save_err:
+                     # Log error saving/zipping this stem but continue with the others
+                     logger.error(f"Failed to save or add stem '{stem_name}' to zip: {save_err}", exc_info=True)
+                     if stem_path: cleanup_file(stem_path)  # Clean up if saved but couldn't zip
+             else:
+                 # This case should be prevented by the earlier validation
+                 logger.warning(f"Requested stem '{stem_name}' not found in model output (validation error?).")
 
          zipf.close()  # Close zip file BEFORE seeking/reading
+         zipf = None  # Clear variable to indicate closed
 
+         if files_added_to_zip == 0:
+             logger.error("Failed to add any requested stems to the ZIP archive.")
+             raise HTTPException(status_code=500, detail="Failed to generate any of the requested stems.")
+
+         zip_buffer.seek(0)  # Rewind buffer pointer for reading
+
+         # Create final ZIP filename
+         zip_filename = f"separated_{actual_model_key}_{os.path.splitext(file.filename)[0]}.zip"
+         logger.info(f"Sending ZIP file: {zip_filename}")
          return StreamingResponse(
+             iter([zip_buffer.getvalue()]),  # StreamingResponse needs an iterator
              media_type="application/zip",
              headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
          )
+     except HTTPException as http_exc:
+         logger.error(f"HTTP Exception during separation: {http_exc.detail}")
          if zipf: zipf.close()  # Ensure zipfile is closed
          if zip_buffer: zip_buffer.close()
+         for path in stem_output_paths.values(): cleanup_file(path)  # Clean up successful stems
+         raise http_exc
+     except Exception as e:
+         logger.error(f"Unexpected error during separation operation: {e}", exc_info=True)
+         if zipf: zipf.close()
+         if zip_buffer: zip_buffer.close()
          for path in stem_output_paths.values(): cleanup_file(path)
+         raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during separation: {str(e)}")
+     finally:
+         # Ensure buffer is closed if not already done
+         if zip_buffer and not zip_buffer.closed:
+             zip_buffer.close()
+
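+ # Example client call (illustrative sketch, not part of the app): requesting two
+ # stems from /separate and unpacking the returned ZIP, assuming a local server on
+ # port 7860 and the default Demucs model loaded at startup. `requests` encodes the
+ # list value as repeated `stems` form fields, which FastAPI parses into List[str]:
+ #
+ #   import io, zipfile, requests
+ #
+ #   with open("song.mp3", "rb") as f:
+ #       resp = requests.post(
+ #           "http://localhost:7860/separate",
+ #           files={"file": ("song.mp3", f, "audio/mpeg")},
+ #           data={"stems": ["vocals", "drums"], "output_format": "wav"},
+ #           timeout=1800,  # separation can take minutes on CPU
+ #       )
+ #   resp.raise_for_status()
+ #   with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
+ #       zf.extractall("stems_out")  # writes vocals.wav, drums.wav, ...
+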
+ # --- How to Run ---
+ # 1. Ensure FFmpeg is installed and accessible on your PATH.
+ # 2. Save this code as `app.py`.
+ # 3. Create `requirements.txt` (including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs, python-multipart, protobuf).
+ # 4. Install dependencies: `pip install -r requirements.txt` (this can take significant time and disk space!).
+ # 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (port 7860 is the Hugging Face Spaces default; add `--reload` only for local development, never in production).
+ #
+ # --- WARNING ---
+ # - The AI models require SIGNIFICANT RAM (often 8GB+) and CPU/GPU. Inference can be SLOW (minutes). Free HF Spaces may time out or lack resources.
+ # - The first run downloads the models (this can take a long time and a lot of disk space).
+ # - Ensure model names (e.g., "htdemucs") are correct.
+ # - MONITOR STARTUP LOGS carefully for model loading success/failure. Failures there will surface later as 503 errors from the AI endpoints.
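+ #
+ # Example `requirements.txt` (illustrative only; entries are unpinned here -- pin
+ # versions that match your Python/CUDA setup before deploying):
+ #
+ #   fastapi
+ #   uvicorn[standard]
+ #   python-multipart
+ #   pydub
+ #   torch
+ #   soundfile
+ #   librosa
+ #   speechbrain
+ #   demucs
+ #   protobuf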