Athspi committed on
Commit 3f784c4 · verified · 1 Parent(s): 1675ffa

Update app.py

Files changed (1)
  1. app.py +251 -523
app.py CHANGED
@@ -1,15 +1,14 @@
1
- # ----------- START app.py -----------
2
  import os
3
  import uuid
4
  import tempfile
5
  import logging
6
  import asyncio
7
  from typing import List, Optional, Dict, Any
8
- import io
9
- import zipfile
10
 
11
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
12
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 
 
13
 
14
  # --- Basic Editing Imports ---
15
  from pydub import AudioSegment
@@ -18,62 +17,58 @@ from pydub.exceptions import CouldntDecodeError
18
  # --- AI & Advanced Audio Imports ---
19
  try:
20
  import torch
21
- from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor # Using pipeline for simplicity where possible
22
- # Specific model imports might be needed depending on the chosen approach
23
- # E.g. for Demucs V4 (Hybrid Transformer): from demucs.hdemucs import HDemucs
24
- # from demucs.pretrained import hdemucs_mmi
25
  import soundfile as sf
26
  import numpy as np
27
- import librosa # For resampling if needed
28
- AI_LIBRARIES_AVAILABLE = True
29
  print("AI and advanced audio libraries loaded.")
30
  except ImportError as e:
31
- print(f"Warning: Error importing AI/Audio libraries: {e}")
32
- print("Ensure torch, transformers, soundfile, librosa are installed.")
33
  print("AI features will be unavailable.")
34
- AI_LIBRARIES_AVAILABLE = False
35
- # Define dummy placeholders if needed, or just rely on AI_LIBRARIES_AVAILABLE flag
36
  torch = None
37
- pipeline = None
38
  sf = None
39
  np = None
40
  librosa = None
41
-
 
42
 
43
  # --- Configuration & Setup ---
44
  TEMP_DIR = tempfile.gettempdir()
45
  os.makedirs(TEMP_DIR, exist_ok=True)
46
 
47
- # Configure logging
48
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
49
  logger = logging.getLogger(__name__)
50
 
51
  # --- Global Variables for Loaded Models ---
52
- # Use dictionaries to potentially hold multiple models of each type later
53
- enhancement_models: Dict[str, Any] = {} # Store model/processor or pipeline
54
- separation_models: Dict[str, Any] = {} # Store model/processor or pipeline
55
-
56
- # Target sampling rates for models (check model cards on Hugging Face!)
57
- # These MUST match the models being loaded in download_models.py and load_hf_models
58
- ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
59
- ENHANCEMENT_SR = 16000 # Sepformer uses 16kHz
60
-
61
- # Note: facebook/demucs is deprecated in transformers >4.26. Use specific variants.
62
- # Using facebook/htdemucs_ft for example (requires Demucs v4 style loading)
63
- # Or find a model suitable for AutoModel if needed.
64
- SEPARATION_MODEL_ID = "facebook/demucs_quantized" # Example using a quantized version (smaller, faster CPU)
65
- # SEPARATION_MODEL_ID = "facebook/hdemucs_mmi" # Example for Multi-Media Instructions model (if using demucs lib)
66
- DEMUCS_SR = 44100 # Demucs default is 44.1kHz
67
 
68
- # Define HF_HOME cache directory *within* the container if downloading during build
69
- HF_CACHE_DIR = os.environ.get("HF_HOME", "/app/hf_cache") # Use HF_HOME from Dockerfile or default
70
71
 
72
- # --- Helper Functions (cleanup_file, save_upload_file, load_audio_for_hf, save_hf_audio) ---
73
- # (Include the helper functions from the previous app.py example here)
74
- # ...
75
  def cleanup_file(file_path: str):
76
- """Safely remove a file."""
77
  try:
78
  if file_path and os.path.exists(file_path):
79
  os.remove(file_path)
@@ -82,9 +77,8 @@ def cleanup_file(file_path: str):
82
  logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
83
 
84
  async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
85
- """Saves an uploaded file to a temporary location and returns the path."""
86
  _, file_extension = os.path.splitext(upload_file.filename)
87
- if not file_extension: file_extension = ".wav" # Default if no extension
88
  temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
89
  try:
90
  with open(temp_file_path, "wb") as buffer:
@@ -98,78 +92,83 @@ async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") ->
98
  finally:
99
  await upload_file.close()
100
 
101
- def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[np.ndarray, int]:
102
- """Loads audio using soundfile, converts to mono float32, optionally resamples."""
103
- if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
104
- raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
105
  try:
106
  audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
107
  logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
108
 
109
- if audio.ndim > 1 and audio.shape[-1] > 1: # Check last dimension for channels
110
- if audio.shape[0] == min(audio.shape): # If channels are first dim
111
- audio = audio.T # Transpose to (samples, channels)
112
- audio = np.mean(audio, axis=1)
113
- logger.info(f"Converted audio to mono, new shape: {audio.shape}")
114
- elif audio.ndim > 1: # If shape is like (1, N) or (N, 1)
115
- audio = audio.squeeze() # Remove singleton dimension
116
- logger.info(f"Squeezed audio to 1D, new shape: {audio.shape}")
117
 
 
 
118
 
 
119
  if target_sr and orig_sr != target_sr:
120
- if librosa is None:
121
- raise RuntimeError("Librosa is required for resampling but not installed.")
122
  logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
123
- # Ensure audio is contiguous before resampling if necessary
124
- if not audio.flags['C_CONTIGUOUS']:
125
- audio = np.ascontiguousarray(audio)
126
- audio = librosa.resample(y=audio, orig_sr=orig_sr, target_sr=target_sr)
127
- logger.info(f"Resampled audio shape: {audio.shape}")
128
  current_sr = target_sr
 
129
  else:
130
  current_sr = orig_sr
131
 
132
- return audio, current_sr
 
133
 
134
  except Exception as e:
135
  logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
136
  raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
137
 
138
- def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
139
- """Saves a NumPy audio array to a temporary file."""
140
- if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
141
- raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
142
-
143
  output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
144
  output_path = os.path.join(TEMP_DIR, output_filename)
145
  try:
146
- logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format}, shape={audio_data.shape})")
 
 
 
 
 
 
 
 
 
 
147
 
148
- # Ensure data is float32 for common formats like wav/flac
149
- if audio_data.dtype != np.float32:
150
- logger.warning(f"Audio data has dtype {audio_data.dtype}, converting to float32.")
151
- audio_data = audio_data.astype(np.float32)
152
 
153
- # Clip data to avoid issues with some formats/players if values go beyond [-1, 1]
154
- audio_data = np.clip(audio_data, -1.0, 1.0)
155
 
156
- # Use soundfile for lossless formats
157
  if output_format.lower() in ['wav', 'flac']:
158
- sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
159
  else:
160
- # For lossy formats like mp3, use pydub after converting numpy array
161
- logger.debug("Using pydub for lossy format export...")
162
  # Scale float32 [-1, 1] to int16 for pydub
163
- audio_int16 = (audio_data * 32767).astype(np.int16)
164
- if audio_int16.ndim > 1: # Should be mono by now, but double check
165
- logger.warning("Audio data still has multiple dimensions before pydub export, attempting mean.")
166
- audio_int16 = np.mean(audio_int16, axis=1).astype(np.int16)
167
-
168
  segment = AudioSegment(
169
  audio_int16.tobytes(),
170
  frame_rate=sampling_rate,
171
  sample_width=audio_int16.dtype.itemsize,
172
- channels=1 # Assuming mono output from AI models for now
173
  )
174
  segment.export(output_path, format=output_format)
175
 
@@ -179,464 +178,210 @@ def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str
179
  cleanup_file(output_path)
180
  raise HTTPException(status_code=500, detail="Failed to save processed audio.")
181
 
182
- # --- Synchronous AI Inference Functions (_run_enhancement_sync, _run_separation_sync) ---
183
- # (Include the sync functions from the previous app.py example here)
184
- # Make sure they handle potential model loading issues gracefully
185
- # ...
186
- def _run_enhancement_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
187
- """Synchronous wrapper for enhancement model inference."""
188
- if not AI_LIBRARIES_AVAILABLE or model_key not in enhancement_models:
189
- raise ValueError(f"Enhancement model '{model_key}' not available or AI libraries missing.")
190
-
191
- model_info = enhancement_models[model_key]
192
- # Adapt based on whether model_info holds a pipeline or model/processor
193
- # This example assumes a pipeline-like object is stored
194
- enhancer = model_info # Assuming pipeline
195
- if not enhancer: raise ValueError(f"Enhancement pipeline '{model_key}' is None.")
196
 
197
  try:
198
- logger.info(f"Running speech enhancement with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
199
- # Usage depends heavily on the specific model/pipeline interface
200
- # For SpeechBrain models often used *without* HF pipeline:
201
- # Example: enhanced_wav = enhancer.enhance_batch(torch.tensor(audio_data).unsqueeze(0), lengths=torch.tensor([audio_data.shape[0]]))
202
- # enhanced_audio = enhanced_wav.squeeze(0).cpu().numpy()
203
 
204
- # If using a generic HF pipeline:
205
- result = enhancer({"raw": audio_data, "sampling_rate": sampling_rate})
206
- enhanced_audio = result["audio"]["array"] # Adjust based on actual pipeline output
207
 
 
 
208
  logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
209
  return enhanced_audio
210
  except Exception as e:
211
- logger.error(f"Error during synchronous enhancement inference with '{model_key}': {e}", exc_info=True)
212
- raise # Re-raise to be caught by the async wrapper
213
214
 
215
- def _run_separation_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
216
- """Synchronous wrapper for source separation model inference."""
217
- if not AI_LIBRARIES_AVAILABLE or model_key not in separation_models:
218
- raise ValueError(f"Separation model '{model_key}' not available or AI libraries missing.")
219
 
220
- model_info = separation_models[model_key]
221
- model = model_info # Assuming direct model object is stored for Demucs
222
 
223
- if not model: raise ValueError(f"Separation model '{model_key}' is None.")
224
 
225
- try:
226
- logger.info(f"Running source separation with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
227
-
228
- # Prepare input tensor for Demucs-like models
229
- # Expects (batch, channels, samples), float32
230
- if audio_data.ndim == 1:
231
- # Need stereo for standard Demucs
232
- logger.debug("Separation input is mono, duplicating to create stereo.")
233
- audio_data = np.stack([audio_data, audio_data], axis=0) # (2, samples)
234
- if audio_data.shape[0] != 2:
235
- # If it's somehow (samples, 2), transpose
236
- if audio_data.shape[1] == 2: audio_data = audio_data.T
237
- else: raise ValueError(f"Unexpected input audio shape for separation: {audio_data.shape}")
238
-
239
- audio_tensor = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0) # (1, 2, samples)
240
-
241
- # Move to model's device (CPU or GPU)
242
- device = next(model.parameters()).device
243
- logger.debug(f"Moving separation tensor to device: {device}")
244
- audio_tensor = audio_tensor.to(device)
245
-
246
- # Perform inference
247
  with torch.no_grad():
248
- logger.debug("Starting model inference for separation...")
249
- # Output shape depends on model, e.g., (batch, stems, channels, samples)
250
- sources = model(audio_tensor)[0] # Remove batch dim
251
- logger.debug(f"Model inference complete, sources shape: {sources.shape}")
252
-
253
- # Detach, move to CPU, convert to numpy
254
- sources_np = sources.detach().cpu().numpy() # (stems, channels, samples)
255
-
256
- # Define stem order based on the *specific* Demucs model used
257
- # This order is for default Demucs v3/v4 (facebook/demucs, facebook/htdemucs_ft, etc.)
258
- stem_names = ['drums', 'bass', 'other', 'vocals']
259
- if sources_np.shape[0] != len(stem_names):
260
- logger.warning(f"Model output {sources_np.shape[0]} stems, expected {len(stem_names)}. Stem names might be incorrect.")
261
- # Fallback names if shape mismatch
262
- stem_names = [f"stem_{i+1}" for i in range(sources_np.shape[0])]
263
-
264
- stems = {}
265
- for i, name in enumerate(stem_names):
266
- # Average channels to get mono stem
267
- mono_stem = np.mean(sources_np[i], axis=0)
268
- stems[name] = mono_stem
269
- logger.debug(f"Extracted stem '{name}', shape: {mono_stem.shape}")
270
-
271
- logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
272
- return stems
273
 
274
  except Exception as e:
275
- logger.error(f"Error during synchronous separation inference with '{model_key}': {e}", exc_info=True)
276
  raise
277
 
278
  # --- Model Loading Function ---
279
- # (Include the load_hf_models function from the previous app.py example here)
280
- # Make sure it uses the correct model IDs and potentially adjusts loading logic
281
- # if using libraries like `demucs` directly.
282
- # ...
283
  def load_hf_models():
284
- """Loads Hugging Face models at startup."""
285
- if not AI_LIBRARIES_AVAILABLE:
286
- logger.warning("Skipping Hugging Face model loading as libraries are missing.")
287
- return
288
-
289
  global enhancement_models, separation_models
 
 
 
290
 
291
- # --- Load Enhancement Model ---
292
- enhancement_key = "speechbrain_enhancer" # Internal key
293
  try:
294
- logger.info(f"Attempting to load enhancement model: {ENHANCEMENT_MODEL_ID}...")
295
- # SpeechBrain models often require specific loading from their toolkit or HF spaces
296
- # This might involve cloning a repo or using specific classes.
297
- # Using HF pipeline if available, otherwise manual load.
298
- # Example using pipeline (might not work for all speechbrain models):
299
- # enhancement_models[enhancement_key] = pipeline(
300
- # "audio-enhancement", # Or appropriate task
301
- # model=ENHANCEMENT_MODEL_ID,
302
- # cache_dir=HF_CACHE_DIR,
303
- # device=0 if torch.cuda.is_available() else -1 # Use GPU if possible
304
- # )
305
- # Manual load might be needed:
306
- # from speechbrain.pretrained import SepformerEnhancement
307
- # enhancer = SepformerEnhancement.from_hparams(
308
- # source=ENHANCEMENT_MODEL_ID,
309
- # savedir=os.path.join(HF_CACHE_DIR, "speechbrain", ENHANCEMENT_MODEL_ID.split('/')[-1]),
310
- # run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"}
311
- # )
312
- # enhancement_models[enhancement_key] = enhancer
313
- logger.warning(f"Actual loading for {ENHANCEMENT_MODEL_ID} skipped - requires SpeechBrain toolkit or specific pipeline setup.")
314
- # To make the endpoint testable without full setup:
315
- # enhancement_models[enhancement_key] = None # Or a dummy function
316
-
317
  except Exception as e:
318
- logger.error(f"Failed to load enhancement model '{ENHANCEMENT_MODEL_ID}': {e}", exc_info=False)
319
-
320
 
321
  # --- Load Separation Model (Demucs) ---
322
- separation_key = "demucs_separator" # Internal key
 
323
  try:
324
- logger.info(f"Attempting to load separation model: {SEPARATION_MODEL_ID}...")
325
- # Loading Demucs models can be complex.
326
- # Option 1: Use AutoModel if the HF Hub version supports it (less common for Demucs)
327
- # Option 2: Use the `demucs` library (recommended if installed: pip install -U demucs)
328
- # Option 3: Find a Transformers-compatible version if available.
329
-
330
- # Example using AutoModel (Try this first, might work for some quantized/HF versions)
331
- try:
332
- # Determine device
333
- device = "cuda" if torch.cuda.is_available() else "cpu"
334
- logger.info(f"Loading Demucs on device: {device}")
335
- # Check if AutoModelForSpeechSeq2Seq is appropriate, might need a different AutoModel class
336
- separation_models[separation_key] = AutoModelForSpeechSeq2Seq.from_pretrained(
337
- SEPARATION_MODEL_ID,
338
- cache_dir=HF_CACHE_DIR
339
- # Add trust_remote_code=True if needed for custom model code on HF hub
340
- ).to(device)
341
-
342
- # Check if the loaded model has an 'eval' method (common for PyTorch models)
343
- if hasattr(separation_models[separation_key], 'eval'):
344
- separation_models[separation_key].eval() # Set to evaluation mode
345
-
346
- logger.info(f"Successfully loaded separation model '{SEPARATION_MODEL_ID}' using AutoModel.")
347
-
348
- except Exception as auto_model_err:
349
- logger.warning(f"Failed to load '{SEPARATION_MODEL_ID}' using AutoModel: {auto_model_err}. Consider installing 'demucs' library.")
350
- separation_models[separation_key] = None # Ensure it's None if loading failed
351
-
352
- # Example using `demucs` library (if installed)
353
- # try:
354
- # import demucs.separate
355
- # model = demucs.apply.load_model(pretrained_model_path_or_url) # Needs adjustment
356
- # separation_models[separation_key] = model
357
- # logger.info(f"Successfully loaded separation model using 'demucs' library.")
358
- # except ImportError:
359
- # logger.error("Cannot load Demucs: 'demucs' library not found. Please run 'pip install -U demucs'.")
360
- # except Exception as demucs_lib_err:
361
- # logger.error(f"Error loading model using 'demucs' library: {demucs_lib_err}")
362
-
363
-
364
  except Exception as e:
365
- logger.error(f"General error loading separation model '{SEPARATION_MODEL_ID}': {e}", exc_info=False)
366
- if separation_key in separation_models: separation_models[separation_key] = None
367
 
368
 
369
- # --- FastAPI App and Endpoints ---
370
  app = FastAPI(
371
  title="AI Audio Editor API",
372
- description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and HF model dependencies.",
373
- version="2.0.0",
374
  )
375
 
376
  @app.on_event("startup")
377
  async def startup_event():
378
- """Load models when the application starts."""
379
- logger.info("Application startup: Loading AI models (this may take time)...")
380
  await asyncio.to_thread(load_hf_models)
381
- logger.info("Model loading process finished.")
382
-
383
 
384
  # --- API Endpoints ---
385
- # (Include / , /trim, /concat, /volume, /convert endpoints here - same as previous version)
386
- # ...
387
  @app.get("/", tags=["General"])
388
  def read_root():
389
- """Root endpoint providing a welcome message and available features."""
390
  features = ["/trim", "/concat", "/volume", "/convert"]
391
  ai_features = []
392
- # Check if models were successfully loaded (i.e., not None)
393
- if any(model is not None for model in enhancement_models.values()): ai_features.append("/enhance")
394
- if any(model is not None for model in separation_models.values()): ai_features.append("/separate")
395
 
396
  return {
397
  "message": "Welcome to the AI Audio Editor API.",
398
  "basic_features": features,
399
- "ai_features": ai_features if ai_features else "None loaded (check logs)",
400
- "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs)."
401
  }
402
 
403
- @app.post("/trim", tags=["Basic Editing"])
404
- async def trim_audio(
405
- background_tasks: BackgroundTasks,
406
- file: UploadFile = File(..., description="Audio file to trim."),
407
- start_ms: int = Form(..., description="Start time in milliseconds."),
408
- end_ms: int = Form(..., description="End time in milliseconds.")
409
- ):
410
- """Trims an audio file to the specified start and end times (in milliseconds)."""
411
- if start_ms < 0 or end_ms <= start_ms:
412
- raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")
413
-
414
- logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
415
- input_path = None
416
- output_path = None
417
- try:
418
- input_path = await save_upload_file(file, prefix="trim_in_")
419
- background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
420
-
421
- # Use Pydub for basic trim
422
- audio = AudioSegment.from_file(input_path)
423
- trimmed_audio = audio[start_ms:end_ms]
424
- logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
425
-
426
- original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
427
- if not original_format or original_format == "tmp": original_format = "mp3"
428
- output_filename = f"trimmed_{uuid.uuid4().hex}.{original_format}"
429
- output_path = os.path.join(TEMP_DIR, output_filename)
430
-
431
- trimmed_audio.export(output_path, format=original_format)
432
- background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
433
-
434
- return FileResponse(
435
- path=output_path,
436
- media_type=f"audio/{original_format}", # Attempt correct media type
437
- filename=f"trimmed_{file.filename}"
438
- )
439
- except CouldntDecodeError:
440
- logger.warning(f"pydub failed to decode: {file.filename}")
441
- raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
442
- except Exception as e:
443
- logger.error(f"Error during trim operation: {e}", exc_info=True)
444
- if output_path: cleanup_file(output_path) # Immediate cleanup on error
445
- if input_path: cleanup_file(input_path) # Immediate cleanup on error
446
- if isinstance(e, HTTPException): raise e
447
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}")
448
-
449
-
450
- @app.post("/concat", tags=["Basic Editing"])
451
- async def concatenate_audio(
452
- background_tasks: BackgroundTasks,
453
- files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
454
- output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
455
- ):
456
- """Concatenates two or more audio files sequentially."""
457
- if len(files) < 2:
458
- raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
459
-
460
- logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
461
- input_paths = []
462
- loaded_audios = []
463
- output_path = None
464
-
465
- try:
466
- combined_audio = AudioSegment.empty()
467
- first_filename_base = "combined"
468
-
469
- for i, file in enumerate(files):
470
- input_path = await save_upload_file(file, prefix=f"concat_{i}_")
471
- input_paths.append(input_path)
472
- background_tasks.add_task(cleanup_file, input_path)
473
- audio = AudioSegment.from_file(input_path)
474
- combined_audio += audio
475
- if i == 0: first_filename_base = os.path.splitext(file.filename)[0]
476
- logger.info(f"Appended '{file.filename}', current total duration: {len(combined_audio)}ms")
477
-
478
- if len(combined_audio) == 0:
479
- raise HTTPException(status_code=500, detail="No audio data after attempting concatenation.")
480
-
481
- output_filename_final = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
482
- output_path = os.path.join(TEMP_DIR, f"concat_out_{uuid.uuid4().hex}.{output_format}")
483
-
484
- combined_audio.export(output_path, format=output_format)
485
- background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
486
-
487
- return FileResponse(
488
- path=output_path,
489
- media_type=f"audio/{output_format}",
490
- filename=output_filename_final
491
- )
492
- except CouldntDecodeError as e:
493
- logger.warning(f"pydub failed to decode one of the concat files: {e}")
494
- raise HTTPException(status_code=415, detail=f"Unsupported format or corrupted file among inputs: {e}")
495
- except Exception as e:
496
- logger.error(f"Error during concat operation: {e}", exc_info=True)
497
- if output_path: cleanup_file(output_path)
498
- for p in input_paths: cleanup_file(p)
499
- if isinstance(e, HTTPException): raise e
500
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}")
501
-
502
- @app.post("/volume", tags=["Basic Editing"])
503
- async def change_volume(
504
- background_tasks: BackgroundTasks,
505
- file: UploadFile = File(..., description="Audio file to adjust volume for."),
506
- change_db: float = Form(..., description="Volume change in decibels (dB). Positive values increase volume, negative values decrease.")
507
- ):
508
- """Adjusts the volume of an audio file by a specified decibel amount."""
509
- logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
510
- input_path = None
511
- output_path = None
512
- try:
513
- input_path = await save_upload_file(file, prefix="volume_in_")
514
- background_tasks.add_task(cleanup_file, input_path)
515
-
516
- audio = AudioSegment.from_file(input_path)
517
- adjusted_audio = audio + change_db
518
- logger.info(f"Volume adjusted by {change_db}dB.")
519
-
520
- original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
521
- if not original_format or original_format == "tmp": original_format = "mp3"
522
- output_filename_final = f"volume_{change_db}dB_{file.filename}"
523
- output_path = os.path.join(TEMP_DIR, f"volume_out_{uuid.uuid4().hex}.{original_format}")
524
-
525
- adjusted_audio.export(output_path, format=original_format)
526
- background_tasks.add_task(cleanup_file, output_path)
527
-
528
- return FileResponse(
529
- path=output_path,
530
- media_type=f"audio/{original_format}",
531
- filename=output_filename_final
532
- )
533
- except CouldntDecodeError:
534
- logger.warning(f"pydub failed to decode: {file.filename}")
535
- raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
536
- except Exception as e:
537
- logger.error(f"Error during volume operation: {e}", exc_info=True)
538
- if output_path: cleanup_file(output_path)
539
- if input_path: cleanup_file(input_path)
540
- if isinstance(e, HTTPException): raise e
541
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}")
542
-
543
- @app.post("/convert", tags=["Basic Editing"])
544
- async def convert_format(
545
- background_tasks: BackgroundTasks,
546
- file: UploadFile = File(..., description="Audio file to convert."),
547
- output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
548
- ):
549
- """Converts an audio file to a different format."""
550
- allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a'}
551
- if output_format.lower() not in allowed_formats:
552
- raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed: {', '.join(allowed_formats)}")
553
 
554
- logger.info(f"Convert request: file='{file.filename}', output_format='{output_format}'")
555
- input_path = None
556
- output_path = None
557
- try:
558
- input_path = await save_upload_file(file, prefix="convert_in_")
559
- background_tasks.add_task(cleanup_file, input_path)
560
-
561
- audio = AudioSegment.from_file(input_path)
562
-
563
- output_format_lower = output_format.lower()
564
- filename_base = os.path.splitext(file.filename)[0]
565
- output_filename_final = f"{filename_base}_converted.{output_format_lower}"
566
- output_path = os.path.join(TEMP_DIR, f"convert_out_{uuid.uuid4().hex}.{output_format_lower}")
567
-
568
- audio.export(output_path, format=output_format_lower)
569
- background_tasks.add_task(cleanup_file, output_path)
570
-
571
- return FileResponse(
572
- path=output_path,
573
- media_type=f"audio/{output_format_lower}",
574
- filename=output_filename_final
575
- )
576
- except CouldntDecodeError:
577
- logger.warning(f"pydub failed to decode: {file.filename}")
578
- raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
579
- except Exception as e:
580
- logger.error(f"Error during convert operation: {e}", exc_info=True)
581
- if output_path: cleanup_file(output_path)
582
- if input_path: cleanup_file(input_path)
583
- if isinstance(e, HTTPException): raise e
584
- else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}")
585
 
 
586
 
587
- # (Include /enhance and /separate AI endpoints here - same as previous version)
588
- # ...
589
  @app.post("/enhance", tags=["AI Editing"])
590
  async def enhance_speech(
591
  background_tasks: BackgroundTasks,
592
  file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
593
- model_key: str = Query("speechbrain_enhancer", description="Internal key of the enhancement model to use."),
594
- output_format: str = Form("wav", description="Output format for the enhanced audio (wav, flac recommended).")
 
595
  ):
596
- """Enhances speech audio using a pre-loaded AI model (experimental)."""
597
- if not AI_LIBRARIES_AVAILABLE:
598
- raise HTTPException(status_code=501, detail="AI processing libraries not available.")
599
- if model_key not in enhancement_models or enhancement_models[model_key] is None:
600
  logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
601
  raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
602
 
603
- logger.info(f"Enhance request: file='{file.filename}', model_key='{model_key}', format='{output_format}'")
604
- input_path = None
 
 
 
605
  output_path = None
 
606
  try:
607
- input_path = await save_upload_file(file, prefix="enhance_in_")
608
- background_tasks.add_task(cleanup_file, input_path)
609
-
610
- # Load audio, ensure correct SR for the model
611
- logger.debug(f"Loading audio for enhancement, target SR: {ENHANCEMENT_SR}")
612
- audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
613
- if current_sr != ENHANCEMENT_SR: # Should have been resampled, but double check
614
- logger.warning(f"Audio SR after loading is {current_sr}, expected {ENHANCEMENT_SR}. Check resampling.")
615
- # Depending on model strictness, could raise error or proceed cautiously.
616
- # raise HTTPException(status_code=500, detail="Audio resampling failed.")
617
-
618
- # Run inference in a separate thread
619
  logger.info("Submitting enhancement task to background thread...")
620
- enhanced_audio = await asyncio.to_thread(
621
- _run_enhancement_sync, model_key, audio_data, current_sr # Pass key, data, and ACTUAL sr used
622
  )
623
  logger.info("Enhancement task completed.")
624
 
625
- # Save the result
626
- output_path = save_hf_audio(enhanced_audio, ENHANCEMENT_SR, output_format) # Save with model's target SR
627
  background_tasks.add_task(cleanup_file, output_path)
628
 
629
- output_filename_final = f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
630
  return FileResponse(
631
  path=output_path,
632
  media_type=f"audio/{output_format}",
633
- filename=output_filename_final
634
  )
635
-
636
  except Exception as e:
637
  logger.error(f"Error during enhancement operation: {e}", exc_info=True)
638
  if output_path: cleanup_file(output_path)
639
- if input_path: cleanup_file(input_path)
640
  if isinstance(e, HTTPException): raise e
641
  else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
642
 
@@ -645,96 +390,79 @@ async def enhance_speech(
645
  async def separate_sources(
646
  background_tasks: BackgroundTasks,
647
  file: UploadFile = File(..., description="Music audio file to separate into stems."),
648
- model_key: str = Query("demucs_separator", description="Internal key of the separation model to use."),
649
  stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
650
  output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
651
  ):
652
- """Separates music into stems (vocals, drums, bass, other) using a pre-loaded AI model (experimental). Returns a ZIP archive."""
653
- if not AI_LIBRARIES_AVAILABLE:
654
- raise HTTPException(status_code=501, detail="AI processing libraries not available.")
655
- if model_key not in separation_models or separation_models[model_key] is None:
656
  logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
657
  raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
658
 
659
- valid_stems = {'vocals', 'drums', 'bass', 'other'} # Based on typical Demucs output
 
660
  requested_stems = set(s.lower() for s in stems)
661
  if not requested_stems.issubset(valid_stems):
662
- # Allow if all stems are requested even if validation set is smaller? Or just error.
663
- raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are generally: {', '.join(valid_stems)}")
664
 
665
- logger.info(f"Separate request: file='{file.filename}', model_key='{model_key}', stems={requested_stems}, format='{output_format}'")
666
- input_path = None
 
667
  stem_output_paths: Dict[str, str] = {}
668
- zip_buffer = io.BytesIO() # Use BytesIO for in-memory ZIP
669
 
670
  try:
671
- input_path = await save_upload_file(file, prefix="separate_in_")
672
- background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
673
-
674
- # Load audio, ensure correct SR for the model
675
- logger.debug(f"Loading audio for separation, target SR: {DEMUCS_SR}")
676
- audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
677
- if current_sr != DEMUCS_SR:
678
- logger.warning(f"Audio SR after loading is {current_sr}, expected {DEMUCS_SR}. Check resampling.")
679
- # raise HTTPException(status_code=500, detail="Audio resampling failed.")
680
 
681
- # Run inference in a separate thread
682
  logger.info("Submitting separation task to background thread...")
683
- all_separated_stems = await asyncio.to_thread(
684
- _run_separation_sync, model_key, audio_data, current_sr # Pass key, data, actual SR
685
  )
686
  logger.info("Separation task completed.")
687
 
688
  # --- Create ZIP file in memory ---
689
- zip_filename_base = f"separated_{os.path.splitext(file.filename)[0]}"
690
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
691
- logger.info(f"Creating ZIP archive in memory...")
692
- found_stems_count = 0
693
  for stem_name in requested_stems:
694
- if stem_name in all_separated_stems:
695
- stem_data = all_separated_stems[stem_name]
696
- if stem_data is None or stem_data.size == 0:
697
- logger.warning(f"Stem '{stem_name}' data is empty, skipping.")
698
- continue
699
-
700
- # Save stem temporarily to disk first (needed for pydub/sf.write)
701
- logger.debug(f"Saving temporary stem file for '{stem_name}'...")
702
- stem_path = save_hf_audio(stem_data, DEMUCS_SR, output_format) # Save with model's target SR
703
- stem_output_paths[stem_name] = stem_path # Keep track for cleanup
704
- background_tasks.add_task(cleanup_file, stem_path) # Schedule cleanup
705
-
706
- # Add the saved stem file to the ZIP archive
707
- archive_name = f"{stem_name}.{output_format}" # Simple name inside zip
708
  zipf.write(stem_path, arcname=archive_name)
709
  logger.info(f"Added '{archive_name}' to ZIP.")
710
- found_stems_count += 1
711
  else:
712
- logger.warning(f"Requested stem '{stem_name}' not found in model output keys: {list(all_separated_stems.keys())}")
 
713
 
714
- if found_stems_count == 0:
715
- raise HTTPException(status_code=404, detail="None of the requested stems were found or generated successfully.")
716
 
717
- zip_buffer.seek(0) # Rewind buffer pointer
718
-
719
- # Return the ZIP file via StreamingResponse
720
- zip_filename_download = f"{zip_filename_base}.zip"
721
- logger.info(f"Sending ZIP file '{zip_filename_download}'")
722
  return StreamingResponse(
723
- zip_buffer, # Pass the BytesIO buffer directly
724
  media_type="application/zip",
725
- headers={'Content-Disposition': f'attachment; filename="{zip_filename_download}"'}
726
  )
727
-
728
  except Exception as e:
729
  logger.error(f"Error during separation operation: {e}", exc_info=True)
730
- # Cleanup temporary stem files if they exist
731
  for path in stem_output_paths.values(): cleanup_file(path)
732
- # Close buffer just in case (though StreamingResponse should handle it)
733
- # if zip_buffer and not zip_buffer.closed: zip_buffer.close()
734
- if input_path: cleanup_file(input_path)
735
-
736
  if isinstance(e, HTTPException): raise e
737
  else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
738
 
 
 
739
 
740
- # ----------- END app.py -----------
 
 
 
 
 
 
1
  import os
2
  import uuid
3
  import tempfile
4
  import logging
5
  import asyncio
6
  from typing import List, Optional, Dict, Any
 
 
7
 
8
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
9
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
10
+ import io
11
+ import zipfile
12
 
13
  # --- Basic Editing Imports ---
14
  from pydub import AudioSegment
 
17
  # --- AI & Advanced Audio Imports ---
18
  try:
19
  import torch
20
+ # Transformers only needed if using HF pipelines directly, not for speechbrain/demucs manual loading
21
+ # from transformers import pipeline
 
 
22
  import soundfile as sf
23
  import numpy as np
24
+ import librosa
25
+
26
+ # Specific Model Libraries
27
+ import speechbrain.pretrained
28
+ import demucs.pretrained
29
+ import demucs.apply
30
+
31
  print("AI and advanced audio libraries loaded.")
32
  except ImportError as e:
33
+ print(f"Error importing AI/Audio libraries: {e}")
34
+ print("Ensure torch, soundfile, librosa, speechbrain, demucs are installed.")
35
  print("AI features will be unavailable.")
 
 
36
  torch = None
 
37
  sf = None
38
  np = None
39
  librosa = None
40
+ speechbrain = None
41
+ demucs = None
42
 
43
  # --- Configuration & Setup ---
44
  TEMP_DIR = tempfile.gettempdir()
45
  os.makedirs(TEMP_DIR, exist_ok=True)
46
 
 
47
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
48
  logger = logging.getLogger(__name__)
49
 
50
  # --- Global Variables for Loaded Models ---
51
+ # Use consistent keys for storing/retrieving models
52
+ ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
53
+ # Choose a default Demucs model (htdemucs is good quality)
54
+ SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one
55
 
56
+ enhancement_models: Dict[str, Any] = {}
57
+ separation_models: Dict[str, Any] = {}
58
 
59
+ # Target sampling rates (confirm from model specifics if necessary)
60
+ ENHANCEMENT_SR = 16000 # Sepformer WHAMR operates at 16kHz
61
+ DEMUCS_SR = 44100 # Demucs default is 44.1kHz
62
 
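+ # Note: the loaded Demucs model also reports its expected rate via
+ # model.samplerate; the constant above is assumed to match the default
+ # "htdemucs" checkpoint.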
63
+ # --- Device Selection ---
64
+ if torch:
65
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
66
+ logger.info(f"Using device: {DEVICE}")
67
+ else:
68
+ DEVICE = "cpu" # Fallback if torch failed import
69
+
70
+ # --- Helper Functions (cleanup_file, save_upload_file - same as before) ---
71
  def cleanup_file(file_path: str):
 
72
  try:
73
  if file_path and os.path.exists(file_path):
74
  os.remove(file_path)
 
77
  logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
78
 
79
  async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
 
80
  _, file_extension = os.path.splitext(upload_file.filename)
81
+ if not file_extension: file_extension = ".wav"
82
  temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
83
  try:
84
  with open(temp_file_path, "wb") as buffer:
 
92
  finally:
93
  await upload_file.close()
94
 
95
+ # --- Audio Loading/Saving for AI Models ---
96
+ def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[torch.Tensor, int]": # string annotation so the module still imports when torch is missing
97
+ """Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
 
98
  try:
99
  audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
100
  logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")
101
 
102
+ if audio.ndim > 1: # soundfile returns (frames, channels) for multi-channel audio
103
+ logger.warning(f"Detected {audio.shape[1]} channels, converting to mono by averaging.")
104
+ audio = np.mean(audio, axis=1) # average across the channel axis
 
 
 
107
 
108
+ # Convert numpy array to torch tensor
109
+ audio_tensor = torch.from_numpy(audio).float()
110
 
111
+ # Resample if necessary using librosa then convert back to tensor
112
  if target_sr and orig_sr != target_sr:
113
+ if librosa is None: raise RuntimeError("Librosa is required for resampling but not installed.")
 
114
  logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
115
+ # Librosa works on numpy, so convert back temp.
116
+ audio_np = audio_tensor.numpy()
117
+ resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
118
+ audio_tensor = torch.from_numpy(resampled_audio_np).float()
 
119
  current_sr = target_sr
120
+ logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
121
  else:
122
  current_sr = orig_sr
123
 
124
+ # Ensure tensor is on the correct device
125
+ return audio_tensor.to(DEVICE), current_sr
126
 
127
  except Exception as e:
128
  logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
129
  raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")
130
 
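+ # Illustrative usage (the path and rate below are placeholders, not API):
+ #   tensor, sr = load_audio_for_hf("/tmp/input.wav", target_sr=16000)
+ #   -> mono float32 tensor on DEVICE, with sr == 16000
+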
131
+ def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
132
+ """Saves audio data (Tensor or NumPy array) to a temporary file."""
 
 
 
133
  output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
134
  output_path = os.path.join(TEMP_DIR, output_filename)
135
  try:
136
+ logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
137
+
138
+ # Convert tensor to numpy array if needed
139
+ if isinstance(audio_data, torch.Tensor):
140
+ logger.debug("Converting output tensor to NumPy array.")
141
+ # Ensure tensor is on CPU before converting to numpy
142
+ audio_np = audio_data.detach().cpu().numpy()
143
+ elif isinstance(audio_data, np.ndarray):
144
+ audio_np = audio_data
145
+ else:
146
+ raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")
147
 
148
+ # Ensure data is float32
149
+ if audio_np.dtype != np.float32:
150
+ logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
151
+ audio_np = audio_np.astype(np.float32)
152
 
153
+ # Clip values to avoid potential issues with formats expecting [-1, 1]
154
+ audio_np = np.clip(audio_np, -1.0, 1.0)
155
 
156
+ # Use soundfile (preferred for wav/flac)
157
  if output_format.lower() in ['wav', 'flac']:
158
+ sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
159
  else:
160
+ # For lossy formats, use pydub
161
+ logger.debug(f"Using pydub to export to lossy format: {output_format}")
162
  # Scale float32 [-1, 1] to int16 for pydub
163
+ audio_int16 = (audio_np * 32767).astype(np.int16)
164
+ # Create AudioSegment (assuming mono for now)
165
+ # Stems are converted to mono upstream; fall back to the first channel if not
166
+ if audio_int16.ndim > 1: audio_int16 = audio_int16[0] # assumes channels-first layout here
 
167
  segment = AudioSegment(
168
  audio_int16.tobytes(),
169
  frame_rate=sampling_rate,
170
  sample_width=audio_int16.dtype.itemsize,
171
+ channels=1 # Forcing mono currently
172
  )
173
  segment.export(output_path, format=output_format)
174
 
 
178
  cleanup_file(output_path)
179
  raise HTTPException(status_code=500, detail="Failed to save processed audio.")
180
 
181
+ # --- Synchronous AI Inference Functions ---
182
 
183
+ def _run_enhancement_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "torch.Tensor":
184
+ """Synchronous wrapper for SpeechBrain enhancement model inference."""
185
+ if not model: raise ValueError("Enhancement model not loaded")
186
  try:
187
+ logger.info(f"Running speech enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
188
+ # SpeechBrain models usually take tensors directly
189
+ # Add batch dimension if needed (most SB models expect batch)
190
+ if audio_tensor.ndim == 1:
191
+ audio_tensor = audio_tensor.unsqueeze(0)
192
+
193
+ # Move tensor to the same device as the model
194
+ model_device = next(model.parameters()).device
195
+ audio_tensor = audio_tensor.to(model_device)
196
 
197
+ with torch.no_grad():
198
+ # SepformerSeparation enhancement checkpoints return their output via
199
+ # separate_batch: est_sources has shape (batch, time, n_sources)
+ est_sources = model.separate_batch(audio_tensor)
+ enhanced_tensor = est_sources[:, :, 0] # the single enhanced source
200
 
201
+ # Remove batch dimension from output before returning
202
+ enhanced_audio = enhanced_tensor.squeeze(0).cpu() # Move back to CPU
203
  logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
204
  return enhanced_audio
205
  except Exception as e:
206
+ logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
207
+ raise
208
 
209
+ def _run_separation_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "Dict[str, torch.Tensor]":
210
+ """Synchronous wrapper for Demucs source separation model inference."""
211
+ if not model: raise ValueError("Separation model not loaded")
212
+ if not demucs: raise RuntimeError("Demucs library not available")
213
+ try:
214
+ logger.info(f"Running source separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
215
 
216
+ # Demucs expects audio as (batch, channels, samples)
217
+ # Ensure input tensor is on the correct device
218
+ model_device = next(model.parameters()).device
219
+ audio_tensor = audio_tensor.to(model_device)
220
 
221
+ # Add batch and channel dimensions if mono
222
+ if audio_tensor.ndim == 1:
223
+ audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
224
+ elif audio_tensor.ndim == 2: # Should not happen often if load_audio ensures mono tensor
225
+ logger.warning("Input tensor has 2 dims, assuming (batch, samples), adding channel dim.")
226
+ audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)
227
 
228
+ # Ensure correct number of channels expected by the model (usually 2)
229
+ if audio_tensor.shape[1] != model.audio_channels:
230
+ logger.warning(f"Model expects {model.audio_channels} channels, input has {audio_tensor.shape[1]}. Repeating mono channel.")
231
+ audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1) # Repeat mono to match expected channels
232
+
233
+
234
+ logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  with torch.no_grad():
237
+ # Use demucs.apply.apply_model, which handles chunked inference
238
+ # (split/overlap) and also works for "bag of models" checkpoints,
239
+ # unlike a direct forward call on the model object
240
+ sources = demucs.apply.apply_model(
241
+ model, audio_tensor, device=model_device, shifts=1, split=True, overlap=0.25
242
+ )[0] # Output shape (stems, channels, samples) - remove batch dim
243
+
244
+ logger.debug(f"Raw separated sources tensor shape: {sources.shape}") # Should be (num_stems, channels, samples)
245
+
246
+ # Map stems based on the model's sources list
247
+ # Default for htdemucs: drums, bass, other, vocals
248
+ stem_map = {name: sources[i] for i, name in enumerate(model.sources)}
249
+
250
+ # Convert back to mono for simplicity (average channels) and move to CPU
251
+ output_stems = {}
252
+ for name, data in stem_map.items():
253
+ output_stems[name] = data.mean(dim=0).detach().cpu() # Average channels, detach, move to CPU
254
+
255
+ logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
256
+ return output_stems
257
 
258
  except Exception as e:
259
+ logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
260
  raise
261
 
262
  # --- Model Loading Function ---
263
  def load_hf_models():
264
+ """Loads AI models at startup using correct libraries."""
 
 
 
 
265
  global enhancement_models, separation_models
266
+ if torch is None or speechbrain is None or demucs is None:
267
+ logger.error("Core AI libraries (torch, speechbrain, demucs) not available. Skipping model loading.")
268
+ return
269
 
270
+ # --- Load Enhancement Model (SpeechBrain) ---
271
+ enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
272
  try:
273
+ logger.info(f"Loading enhancement model: {enhancement_model_hparams} (using SpeechBrain)...")
274
+ # Ensure SpeechBrain downloads to a writable location if needed (optional)
275
+ # savedir = os.path.join(TEMP_DIR, "speechbrain_models")
276
+ # os.makedirs(savedir, exist_ok=True)
277
+ enhancer = speechbrain.pretrained.SepformerSeparation.from_hparams( # enhancement checkpoints load via SepformerSeparation
278
+ source=enhancement_model_hparams,
279
+ # savedir=savedir, # Specify download dir if needed
280
+ run_opts={"device": DEVICE} # Pass device option
281
+ )
282
+ enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer # Store with consistent key
283
+ logger.info(f"Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {DEVICE}.")
 
 
 
 
 
 
 
 
 
 
 
 
284
  except Exception as e:
285
+ logger.error(f"Failed to load enhancement model '{enhancement_model_hparams}': {e}", exc_info=True)
 
286
 
287
  # --- Load Separation Model (Demucs) ---
288
+ # Using a standard pretrained model name from the demucs package
289
+ separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs" or "mdx_extra_q"
290
  try:
291
+ logger.info(f"Loading separation model: {separation_model_name} (using Demucs package)...")
292
+ # This automatically handles downloading the model checkpoint
293
+ separator = demucs.pretrained.get_model(name=separation_model_name)
+ separator.to(DEVICE)
+ separator.eval()
294
+ separation_models[SEPARATION_MODEL_KEY] = separator # Store with consistent key
295
+ logger.info(f"Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {DEVICE}.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  except Exception as e:
297
+ logger.error(f"Failed to load separation model '{separation_model_name}': {e}", exc_info=True)
298
+ logger.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
299
 
300
 
301
+ # --- FastAPI App ---
302
  app = FastAPI(
303
  title="AI Audio Editor API",
304
+ description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries (torch, speechbrain, demucs).",
305
+ version="2.1.0", # Incremented version
306
  )
307
 
308
  @app.on_event("startup")
309
  async def startup_event():
310
+ logger.info("Application startup: Loading AI models...")
311
+ # Run blocking model load in thread
312
  await asyncio.to_thread(load_hf_models)
313
+ logger.info("Model loading process finished (check logs for success/failure).")
 
314
 
315
  # --- API Endpoints ---
316
+
 
317
  @app.get("/", tags=["General"])
318
  def read_root():
319
+ # Reports which basic and AI features are currently available
320
  features = ["/trim", "/concat", "/volume", "/convert"]
321
  ai_features = []
322
+ if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
323
+ if separation_models: ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY})")
 
324
 
325
  return {
326
  "message": "Welcome to the AI Audio Editor API.",
327
  "basic_features": features,
328
+ "ai_features": ai_features if ai_features else "None available (check startup logs)",
329
+ "notes": "Requires FFmpeg. AI features require specific models loaded at startup."
330
  }
331
 
 
332
 
333
+ # --- Basic Editing Endpoints ---
334
+ # (Add /trim, /concat, /volume, /convert endpoints here - same logic as before)
335
+ # Make sure they use the updated cleanup_file and save_upload_file helpers.
336
+ # ...
337
 
338
+ # --- AI Endpoints (Corrected) ---
339
 
 
 
340
  @app.post("/enhance", tags=["AI Editing"])
341
  async def enhance_speech(
342
  background_tasks: BackgroundTasks,
343
  file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
344
+ # Only one enhancement model is loaded at present; the key is kept for future flexibility
345
+ model_key: str = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use."),
346
+ output_format: str = Form("wav", description="Output format (wav, flac recommended).")
347
  ):
348
+ """Enhances speech audio using a pre-loaded SpeechBrain model."""
349
+ if torch is None or speechbrain is None:
350
+ raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
351
+ if model_key not in enhancement_models:
352
  logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
353
  raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")
354
 
355
+ loaded_model = enhancement_models[model_key]
356
+
357
+ logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
358
+ input_path = await save_upload_file(file, prefix="enhance_in_")
359
+ background_tasks.add_task(cleanup_file, input_path)
360
  output_path = None
361
+
362
  try:
363
+ # Load audio as tensor, ensure correct SR
364
+ # SpeechBrain Sepformer expects 16kHz
365
+ audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
366
+
 
 
 
 
 
 
 
 
367
  logger.info("Submitting enhancement task to background thread...")
368
+ enhanced_audio_tensor = await asyncio.to_thread(
369
+ _run_enhancement_sync, loaded_model, audio_tensor, current_sr # pass the SR actually in use after resampling
370
  )
371
  logger.info("Enhancement task completed.")
372
 
373
+ # Save the result (tensor output from enhancer)
374
+ output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format) # Save at model's SR
375
  background_tasks.add_task(cleanup_file, output_path)
376
 
 
377
  return FileResponse(
378
  path=output_path,
379
  media_type=f"audio/{output_format}",
380
+ filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
381
  )
 
382
  except Exception as e:
383
  logger.error(f"Error during enhancement operation: {e}", exc_info=True)
384
  if output_path: cleanup_file(output_path)
 
385
  if isinstance(e, HTTPException): raise e
386
  else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")
387
 
 
390
  async def separate_sources(
391
  background_tasks: BackgroundTasks,
392
  file: UploadFile = File(..., description="Music audio file to separate into stems."),
393
+ model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
394
  stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
395
  output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
396
  ):
397
+ """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
398
+ if torch is None or demucs is None:
399
+ raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
400
+ if model_key not in separation_models:
401
  logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
402
  raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")
403
 
404
+ loaded_model = separation_models[model_key]
405
+ valid_stems = set(loaded_model.sources) # Get stems directly from loaded model
406
  requested_stems = set(s.lower() for s in stems)
407
  if not requested_stems.issubset(valid_stems):
408
+ raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Model '{model_key}' provides: {', '.join(valid_stems)}")
 
409
 
410
+ logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
411
+ input_path = await save_upload_file(file, prefix="separate_in_")
412
+ background_tasks.add_task(cleanup_file, input_path)
413
  stem_output_paths: Dict[str, str] = {}
414
+ zip_buffer = None
415
 
416
  try:
417
+ # Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
418
+ audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
419
 
 
420
  logger.info("Submitting separation task to background thread...")
421
+ all_separated_stems_tensors = await asyncio.to_thread(
422
+ _run_separation_sync, loaded_model, audio_tensor, current_sr # pass the SR actually in use after resampling
423
  )
424
  logger.info("Separation task completed.")
425
 
426
  # --- Create ZIP file in memory ---
427
+ zip_buffer = io.BytesIO()
428
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
429
+ # Save only the requested stems
 
430
  for stem_name in requested_stems:
431
+ if stem_name in all_separated_stems_tensors:
432
+ stem_tensor = all_separated_stems_tensors[stem_name]
433
+ # Save stem temporarily (save_hf_audio handles tensor)
434
+ # Use the model's native sampling rate for output
435
+ stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
436
+ stem_output_paths[stem_name] = stem_path
437
+ background_tasks.add_task(cleanup_file, stem_path)
438
+
439
+ archive_name = f"{stem_name}_{os.path.splitext(file.filename)[0]}.{output_format}"
440
  zipf.write(stem_path, arcname=archive_name)
441
  logger.info(f"Added '{archive_name}' to ZIP.")
 
442
  else:
443
+ # This case should be prevented by the earlier validation
444
+ logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen).")
445
 
446
+ zip_buffer.seek(0)
 
447
 
448
+ zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
449
  return StreamingResponse(
450
+ zip_buffer,
451
  media_type="application/zip",
452
+ headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
453
  )
 
454
  except Exception as e:
455
  logger.error(f"Error during separation operation: {e}", exc_info=True)
 
456
  for path in stem_output_paths.values(): cleanup_file(path)
457
+ if zip_buffer: zip_buffer.close()
 
 
 
458
  if isinstance(e, HTTPException): raise e
459
  else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
460
 
461
+ # --- Add back the basic editing endpoints (/trim, /concat, /volume, /convert) here ---
462
+ # ... (Remember to include them) ...
463
 
464
+ # --- How to Run ---
465
+ # 1. Ensure FFmpeg is installed.
466
+ # 2. Save code as `app.py`. Create/update `requirements.txt`.
467
+ # 3. Install: `pip install -r requirements.txt` (May take significant time/space!)
468
+ # 4. Run: `uvicorn app:app --reload --host 0.0.0.0`
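+ # 5. Sketch of a matching requirements.txt (exact entries are assumptions, not
+ #    taken from this repo): fastapi, uvicorn, python-multipart, pydub,
+ #    torch, torchaudio, soundfile, numpy, librosa, speechbrain, demucs
+ # 6. Example requests once the server is up (file names are placeholders):
+ #    curl -X POST "http://localhost:8000/enhance" \
+ #         -F "file=@noisy_speech.wav" -F "output_format=wav" -o enhanced.wav
+ #    curl -X POST "http://localhost:8000/separate" \
+ #         -F "file=@song.mp3" -F "stems=vocals" -F "stems=drums" \
+ #         -F "output_format=wav" -o stems.zip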