Athspi committed on
Commit 3ef3c9e · verified · 1 Parent(s): fd0bdf1

Update app.py

Files changed (1)
  1. app.py +466 -180
app.py CHANGED
@@ -1,14 +1,15 @@
+ # ----------- START app.py -----------
  import os
  import uuid
  import tempfile
  import logging
  import asyncio
  from typing import List, Optional, Dict, Any
+ import io
+ import zipfile

  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
  from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
- import io # For zip file in memory
- import zipfile

  # --- Basic Editing Imports ---
  from pydub import AudioSegment
@@ -19,20 +20,26 @@ try:
      import torch
      from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor # Using pipeline for simplicity where possible
      # Specific model imports might be needed depending on the chosen approach
+     # E.g. for Demucs V4 (Hybrid Transformer): from demucs.hdemucs import HDemucs
+     # from demucs.pretrained import hdemucs_mmi
      import soundfile as sf
      import numpy as np
      import librosa # For resampling if needed
+     AI_LIBRARIES_AVAILABLE = True
      print("AI and advanced audio libraries loaded.")
  except ImportError as e:
-     print(f"Error importing AI/Audio libraries: {e}")
+     print(f"Warning: Error importing AI/Audio libraries: {e}")
      print("Ensure torch, transformers, soundfile, librosa are installed.")
      print("AI features will be unavailable.")
+     AI_LIBRARIES_AVAILABLE = False
+     # Define dummy placeholders if needed, or just rely on AI_LIBRARIES_AVAILABLE flag
      torch = None
      pipeline = None
      sf = None
      np = None
      librosa = None

+
  # --- Configuration & Setup ---
  TEMP_DIR = tempfile.gettempdir()
  os.makedirs(TEMP_DIR, exist_ok=True)
@@ -43,15 +50,28 @@ logger = logging.getLogger(__name__)

  # --- Global Variables for Loaded Models ---
  # Use dictionaries to potentially hold multiple models of each type later
- enhancement_pipelines: Dict[str, Any] = {}
- separation_models: Dict[str, Any] = {} # Might store pipeline or model/processor pair
+ enhancement_models: Dict[str, Any] = {} # Store model/processor or pipeline
+ separation_models: Dict[str, Any] = {} # Store model/processor or pipeline

  # Target sampling rates for models (check model cards on Hugging Face!)
- ENHANCEMENT_SR = 16000 # Example for speechbrain/sepformer
- DEMUCS_SR = 44100 # Demucs default
+ # These MUST match the models being loaded in download_models.py and load_hf_models
+ ENHANCEMENT_MODEL_ID = "speechbrain/sepformer-whamr-enhancement"
+ ENHANCEMENT_SR = 16000 # Sepformer uses 16kHz

- # --- Helper Functions ---
+ # Note: facebook/demucs is deprecated in transformers >4.26. Use specific variants.
+ # Using facebook/htdemucs_ft for example (requires Demucs v4 style loading)
+ # Or find a model suitable for AutoModel if needed.
+ SEPARATION_MODEL_ID = "facebook/demucs_quantized" # Example using a quantized version (smaller, faster CPU)
+ # SEPARATION_MODEL_ID = "facebook/hdemucs_mmi" # Example for Multi-Media Instructions model (if using demucs lib)
+ DEMUCS_SR = 44100 # Demucs default is 44.1kHz

+ # Define HF_HOME cache directory *within* the container if downloading during build
+ HF_CACHE_DIR = os.environ.get("HF_HOME", "/app/hf_cache") # Use HF_HOME from Dockerfile or default
+
+
+ # --- Helper Functions (cleanup_file, save_upload_file, load_audio_for_hf, save_hf_audio) ---
+ # (Include the helper functions from the previous app.py example here)
+ # ...
  def cleanup_file(file_path: str):
      """Safely remove a file."""
      try:
@@ -63,7 +83,6 @@ def cleanup_file(file_path: str):

  async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
      """Saves an uploaded file to a temporary location and returns the path."""
-     # Generate a unique temporary file path
      _, file_extension = os.path.splitext(upload_file.filename)
      if not file_extension: file_extension = ".wav" # Default if no extension
      temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
@@ -79,26 +98,32 @@ async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") ->
      finally:
          await upload_file.close()

- # --- Audio Loading/Saving for AI Models ---
-
  def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[np.ndarray, int]:
      """Loads audio using soundfile, converts to mono float32, optionally resamples."""
+     if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
+         raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
      try:
          audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
          logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")

-         # Convert to mono if stereo
-         if audio.ndim > 1 and audio.shape[1] > 1:
-             # Simple averaging for mono conversion
-             audio = np.mean(audio, axis=1)
-             logger.info("Converted audio to mono")
+         if audio.ndim > 1 and audio.shape[-1] > 1: # Check last dimension for channels
+             if audio.shape[0] == min(audio.shape): # If channels are first dim
+                 audio = audio.T # Transpose to (samples, channels)
+             audio = np.mean(audio, axis=1)
+             logger.info(f"Converted audio to mono, new shape: {audio.shape}")
+         elif audio.ndim > 1: # If shape is like (1, N) or (N, 1)
+             audio = audio.squeeze() # Remove singleton dimension
+             logger.info(f"Squeezed audio to 1D, new shape: {audio.shape}")
+

-         # Resample if necessary
          if target_sr and orig_sr != target_sr:
              if librosa is None:
                  raise RuntimeError("Librosa is required for resampling but not installed.")
              logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
-             audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
+             # Ensure audio is contiguous before resampling if necessary
+             if not audio.flags['C_CONTIGUOUS']:
+                 audio = np.ascontiguousarray(audio)
+             audio = librosa.resample(y=audio, orig_sr=orig_sr, target_sr=target_sr)
              logger.info(f"Resampled audio shape: {audio.shape}")
              current_sr = target_sr
          else:
@@ -112,22 +137,34 @@ def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[

  def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str = "wav") -> str:
      """Saves a NumPy audio array to a temporary file."""
+     if not AI_LIBRARIES_AVAILABLE or sf is None or np is None:
+         raise HTTPException(status_code=501, detail="Audio processing libraries (soundfile, numpy) not available.")
+
      output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
      output_path = os.path.join(TEMP_DIR, output_filename)
      try:
-         logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")
-         # Ensure data is float32 for common formats like wav/flac, pydub handles mp3 etc.
+         logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format}, shape={audio_data.shape})")
+
+         # Ensure data is float32 for common formats like wav/flac
          if audio_data.dtype != np.float32:
+             logger.warning(f"Audio data has dtype {audio_data.dtype}, converting to float32.")
              audio_data = audio_data.astype(np.float32)

+         # Clip data to avoid issues with some formats/players if values go beyond [-1, 1]
+         audio_data = np.clip(audio_data, -1.0, 1.0)
+
          # Use soundfile for lossless formats
          if output_format.lower() in ['wav', 'flac']:
              sf.write(output_path, audio_data, sampling_rate, format=output_format.upper())
          else:
              # For lossy formats like mp3, use pydub after converting numpy array
-             # Convert numpy array [-1.0, 1.0] float32 to pydub segment
-             # Scale to 16-bit integer range for pydub if needed
+             logger.debug("Using pydub for lossy format export...")
+             # Scale float32 [-1, 1] to int16 for pydub
              audio_int16 = (audio_data * 32767).astype(np.int16)
+             if audio_int16.ndim > 1: # Should be mono by now, but double check
+                 logger.warning("Audio data still has multiple dimensions before pydub export, attempting mean.")
+                 audio_int16 = np.mean(audio_int16, axis=1).astype(np.int16)
+
              segment = AudioSegment(
                  audio_int16.tobytes(),
                  frame_rate=sampling_rate,
@@ -142,118 +179,191 @@ def save_hf_audio(audio_data: np.ndarray, sampling_rate: int, output_format: str
          cleanup_file(output_path)
          raise HTTPException(status_code=500, detail="Failed to save processed audio.")

- # --- Synchronous AI Inference Functions (to be run in threads) ---
-
- def _run_enhancement_sync(model_pipeline: Any, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
+ # --- Synchronous AI Inference Functions (_run_enhancement_sync, _run_separation_sync) ---
+ # (Include the sync functions from the previous app.py example here)
+ # Make sure they handle potential model loading issues gracefully
+ # ...
+ def _run_enhancement_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray:
      """Synchronous wrapper for enhancement model inference."""
-     if not model_pipeline: raise ValueError("Enhancement model not loaded")
+     if not AI_LIBRARIES_AVAILABLE or model_key not in enhancement_models:
+         raise ValueError(f"Enhancement model '{model_key}' not available or AI libraries missing.")
+
+     model_info = enhancement_models[model_key]
+     # Adapt based on whether model_info holds a pipeline or model/processor
+     # This example assumes a pipeline-like object is stored
+     enhancer = model_info # Assuming pipeline
+     if not enhancer: raise ValueError(f"Enhancement pipeline '{model_key}' is None.")
+
      try:
-         logger.info(f"Running speech enhancement (input shape: {audio_data.shape}, SR: {sampling_rate})...")
-         # Pipeline usage depends heavily on the specific pipeline
-         # Example for a hypothetical 'audio-enhancement' pipeline:
-         result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
+         logger.info(f"Running speech enhancement with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
+         # Usage depends heavily on the specific model/pipeline interface
+         # For SpeechBrain models often used *without* HF pipeline:
+         # Example: enhanced_wav = enhancer.enhance_batch(torch.tensor(audio_data).unsqueeze(0), lengths=torch.tensor([audio_data.shape[0]]))
+         # enhanced_audio = enhanced_wav.squeeze(0).cpu().numpy()
+
+         # If using a generic HF pipeline:
+         result = enhancer({"raw": audio_data, "sampling_rate": sampling_rate})
          enhanced_audio = result["audio"]["array"] # Adjust based on actual pipeline output
+
         logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
         return enhanced_audio
      except Exception as e:
-         logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
+         logger.error(f"Error during synchronous enhancement inference with '{model_key}': {e}", exc_info=True)
          raise # Re-raise to be caught by the async wrapper

- def _run_separation_sync(model_pipeline: Any, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
+
+ def _run_separation_sync(model_key: str, audio_data: np.ndarray, sampling_rate: int) -> Dict[str, np.ndarray]:
      """Synchronous wrapper for source separation model inference."""
-     if not model_pipeline: raise ValueError("Separation model not loaded")
+     if not AI_LIBRARIES_AVAILABLE or model_key not in separation_models:
+         raise ValueError(f"Separation model '{model_key}' not available or AI libraries missing.")
+
+     model_info = separation_models[model_key]
+     model = model_info # Assuming direct model object is stored for Demucs
+
+     if not model: raise ValueError(f"Separation model '{model_key}' is None.")
+
      try:
-         logger.info(f"Running source separation (input shape: {audio_data.shape}, SR: {sampling_rate})...")
-         # Usage depends on the separation model/pipeline
-         # Example for a hypothetical 'audio-source-separation' pipeline:
-         # Note: Actual Demucs might need different handling (e.g., direct model call)
-         # result = model_pipeline({"raw": audio_data, "sampling_rate": sampling_rate})
-
-         # Manual example closer to raw Demucs model (if not using pipeline)
-         # Assuming `separation_models['demucs']` holds the loaded Demucs model instance
-         model = separation_models.get('demucs')
-         if not model: raise ValueError("Demucs model not loaded correctly")
-
-         # Demucs expects stereo input in shape (batch, channels, samples)
-         # Convert mono to stereo if needed, add batch dim
+         logger.info(f"Running source separation with '{model_key}' (input shape: {audio_data.shape}, SR: {sampling_rate})...")
+
+         # Prepare input tensor for Demucs-like models
+         # Expects (batch, channels, samples), float32
          if audio_data.ndim == 1:
-             audio_data = np.stack([audio_data, audio_data], axis=0) # Create stereo from mono
-         audio_tensor = torch.tensor(audio_data).unsqueeze(0) # Add batch dimension
+             # Need stereo for standard Demucs
+             logger.debug("Separation input is mono, duplicating to create stereo.")
+             audio_data = np.stack([audio_data, audio_data], axis=0) # (2, samples)
+         if audio_data.shape[0] != 2:
+             # If it's somehow (samples, 2), transpose
+             if audio_data.shape[1] == 2: audio_data = audio_data.T
+             else: raise ValueError(f"Unexpected input audio shape for separation: {audio_data.shape}")
+
+         audio_tensor = torch.tensor(audio_data, dtype=torch.float32).unsqueeze(0) # (1, 2, samples)

-         # Move to GPU if available and model is on GPU
+         # Move to model's device (CPU or GPU)
          device = next(model.parameters()).device
+         logger.debug(f"Moving separation tensor to device: {device}")
          audio_tensor = audio_tensor.to(device)

+         # Perform inference
          with torch.no_grad():
-             sources = model(audio_tensor)[0] # Output shape (stems, channels, samples)
+             logger.debug("Starting model inference for separation...")
+             # Output shape depends on model, e.g., (batch, stems, channels, samples)
+             sources = model(audio_tensor)[0] # Remove batch dim
+             logger.debug(f"Model inference complete, sources shape: {sources.shape}")

          # Detach, move to CPU, convert to numpy
-         sources_np = sources.detach().cpu().numpy()
-
-         # Convert back to mono for simplicity (average channels)
-         stems = {
-             'drums': np.mean(sources_np[0], axis=0),
-             'bass': np.mean(sources_np[1], axis=0),
-             'other': np.mean(sources_np[2], axis=0),
-             'vocals': np.mean(sources_np[3], axis=0),
-         }
-         # Important: The order (drums, bass, other, vocals) is specific to Demucs v3/v4 default model
+         sources_np = sources.detach().cpu().numpy() # (stems, channels, samples)
+
+         # Define stem order based on the *specific* Demucs model used
+         # This order is for default Demucs v3/v4 (facebook/demucs, facebook/htdemucs_ft, etc.)
+         stem_names = ['drums', 'bass', 'other', 'vocals']
+         if sources_np.shape[0] != len(stem_names):
+             logger.warning(f"Model output {sources_np.shape[0]} stems, expected {len(stem_names)}. Stem names might be incorrect.")
+             # Fallback names if shape mismatch
+             stem_names = [f"stem_{i+1}" for i in range(sources_np.shape[0])]
+
+         stems = {}
+         for i, name in enumerate(stem_names):
+             # Average channels to get mono stem
+             mono_stem = np.mean(sources_np[i], axis=0)
+             stems[name] = mono_stem
+             logger.debug(f"Extracted stem '{name}', shape: {mono_stem.shape}")
+
          logger.info(f"Separation complete. Found stems: {list(stems.keys())}")
          return stems

      except Exception as e:
-         logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
+         logger.error(f"Error during synchronous separation inference with '{model_key}': {e}", exc_info=True)
          raise

  # --- Model Loading Function ---
+ # (Include the load_hf_models function from the previous app.py example here)
+ # Make sure it uses the correct model IDs and potentially adjusts loading logic
+ # if using libraries like `demucs` directly.
+ # ...
  def load_hf_models():
      """Loads Hugging Face models at startup."""
-     global enhancement_pipelines, separation_models
-     if torch is None or pipeline is None:
-         logger.warning("Torch or Transformers not available. Skipping Hugging Face model loading.")
+     if not AI_LIBRARIES_AVAILABLE:
+         logger.warning("Skipping Hugging Face model loading as libraries are missing.")
          return

+     global enhancement_models, separation_models
+
      # --- Load Enhancement Model ---
-     # Using speechbrain/sepformer-whamr-enhancement via pipeline (check HF for exact pipeline task)
-     # Or load model/processor manually if no direct pipeline exists
-     enhancement_model_id = "speechbrain/sepformer-whamr-enhancement" # Example ID
+     enhancement_key = "speechbrain_enhancer" # Internal key
      try:
-         logger.info(f"Loading enhancement model: {enhancement_model_id}...")
-         # Use appropriate task, might be 'audio-enhancement', 'audio-classification' with custom logic, or manual loading
-         # If no pipeline, load manually:
-         # enhancement_processor = AutoProcessor.from_pretrained(...)
-         # enhancement_model = AutoModel...from_pretrained(...)
-         # enhancement_pipelines['speechbrain_sepformer'] = {"processor": enhancement_processor, "model": enhancement_model}
-
-         # For now, let's assume a placeholder pipeline exists or skip if complex
-         # enhancement_pipelines['speechbrain_sepformer'] = pipeline("audio-enhancement", model=enhancement_model_id)
-         logger.warning(f"Skipping load for {enhancement_model_id} - requires specific pipeline or manual setup.")
+         logger.info(f"Attempting to load enhancement model: {ENHANCEMENT_MODEL_ID}...")
+         # SpeechBrain models often require specific loading from their toolkit or HF spaces
+         # This might involve cloning a repo or using specific classes.
+         # Using HF pipeline if available, otherwise manual load.
+         # Example using pipeline (might not work for all speechbrain models):
+         # enhancement_models[enhancement_key] = pipeline(
+         #     "audio-enhancement", # Or appropriate task
+         #     model=ENHANCEMENT_MODEL_ID,
+         #     cache_dir=HF_CACHE_DIR,
+         #     device=0 if torch.cuda.is_available() else -1 # Use GPU if possible
+         # )
+         # Manual load might be needed:
+         # from speechbrain.pretrained import SepformerEnhancement
+         # enhancer = SepformerEnhancement.from_hparams(
+         #     source=ENHANCEMENT_MODEL_ID,
+         #     savedir=os.path.join(HF_CACHE_DIR, "speechbrain", ENHANCEMENT_MODEL_ID.split('/')[-1]),
+         #     run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"}
+         # )
+         # enhancement_models[enhancement_key] = enhancer
+         logger.warning(f"Actual loading for {ENHANCEMENT_MODEL_ID} skipped - requires SpeechBrain toolkit or specific pipeline setup.")
+         # To make the endpoint testable without full setup:
+         # enhancement_models[enhancement_key] = None # Or a dummy function

      except Exception as e:
-         logger.error(f"Failed to load enhancement model '{enhancement_model_id}': {e}", exc_info=False)
+         logger.error(f"Failed to load enhancement model '{ENHANCEMENT_MODEL_ID}': {e}", exc_info=False)


      # --- Load Separation Model (Demucs) ---
-     # Demucs is often used directly, not via a standard HF pipeline task
-     separation_model_id = "facebook/demucs" # Or specific variant like facebook/hybrid_demucs
+     separation_key = "demucs_separator" # Internal key
      try:
-         logger.info(f"Loading separation model: {separation_model_id}...")
-         # Demucs usually requires loading the model directly
-         # Using AutoModel might work for some variants if configured correctly in HF hub
-         # separation_models['demucs'] = AutoModel.from_pretrained(separation_model_id) # Check if this works
-
-         # More typically, you might need to install the 'demucs' package itself: pip install -U demucs
-         # import demucs.separate
-         # model = demucs.apply.load_model(separation_model_id or demucs.pretrained.DEFAULT_MODEL) # Using demucs package
-         # separation_models['demucs'] = model
+         logger.info(f"Attempting to load separation model: {SEPARATION_MODEL_ID}...")
+         # Loading Demucs models can be complex.
+         # Option 1: Use AutoModel if the HF Hub version supports it (less common for Demucs)
+         # Option 2: Use the `demucs` library (recommended if installed: pip install -U demucs)
+         # Option 3: Find a Transformers-compatible version if available.
+
+         # Example using AutoModel (Try this first, might work for some quantized/HF versions)
+         try:
+             # Determine device
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             logger.info(f"Loading Demucs on device: {device}")
+             # Check if AutoModelForSpeechSeq2Seq is appropriate, might need a different AutoModel class
+             separation_models[separation_key] = AutoModelForSpeechSeq2Seq.from_pretrained(
+                 SEPARATION_MODEL_ID,
+                 cache_dir=HF_CACHE_DIR
+                 # Add trust_remote_code=True if needed for custom model code on HF hub
+             ).to(device)
+
+             # Check if the loaded model has an 'eval' method (common for PyTorch models)
+             if hasattr(separation_models[separation_key], 'eval'):
+                 separation_models[separation_key].eval() # Set to evaluation mode
+
+             logger.info(f"Successfully loaded separation model '{SEPARATION_MODEL_ID}' using AutoModel.")
+
+         except Exception as auto_model_err:
+             logger.warning(f"Failed to load '{SEPARATION_MODEL_ID}' using AutoModel: {auto_model_err}. Consider installing 'demucs' library.")
+             separation_models[separation_key] = None # Ensure it's None if loading failed
+
+         # Example using `demucs` library (if installed)
+         # try:
+         #     import demucs.separate
+         #     model = demucs.apply.load_model(pretrained_model_path_or_url) # Needs adjustment
+         #     separation_models[separation_key] = model
+         #     logger.info(f"Successfully loaded separation model using 'demucs' library.")
+         # except ImportError:
+         #     logger.error("Cannot load Demucs: 'demucs' library not found. Please run 'pip install -U demucs'.")
+         # except Exception as demucs_lib_err:
+         #     logger.error(f"Error loading model using 'demucs' library: {demucs_lib_err}")

-         # For now, simulate loading failure as direct AutoModel might not work
-         raise NotImplementedError("Demucs loading typically requires the 'demucs' package or specific manual loading.")
-         logger.info(f"Separation model '{separation_model_id}' loaded.")

      except Exception as e:
-         logger.error(f"Failed to load separation model '{separation_model_id}': {e}", exc_info=False)
-         logger.warning("Note: Demucs loading often requires 'pip install demucs' and specific loading code, not just AutoModel.")
+         logger.error(f"General error loading separation model '{SEPARATION_MODEL_ID}': {e}", exc_info=False)
+         if separation_key in separation_models: separation_models[separation_key] = None


  # --- FastAPI App and Endpoints ---
@@ -266,82 +376,267 @@ app = FastAPI(
  @app.on_event("startup")
  async def startup_event():
      """Load models when the application starts."""
-     logger.info("Application startup: Loading AI models...")
-     # Running model loading in a separate thread to avoid blocking startup completely
-     # although startup will wait for this thread to finish.
-     # Consider truly background loading if startup time is critical.
+     logger.info("Application startup: Loading AI models (this may take time)...")
      await asyncio.to_thread(load_hf_models)
-     logger.info("Model loading process finished (check logs for success/failure).")
-
+     logger.info("Model loading process finished.")

- # --- Basic Editing Endpoints (Mostly Unchanged) ---

+ # --- API Endpoints ---
+ # (Include / , /trim, /concat, /volume, /convert endpoints here - same as previous version)
+ # ...
  @app.get("/", tags=["General"])
  def read_root():
      """Root endpoint providing a welcome message and available features."""
      features = ["/trim", "/concat", "/volume", "/convert"]
      ai_features = []
-     if enhancement_pipelines: ai_features.append("/enhance")
-     if separation_models: ai_features.append("/separate")
+     # Check if models were successfully loaded (i.e., not None)
+     if any(model is not None for model in enhancement_models.values()): ai_features.append("/enhance")
+     if any(model is not None for model in separation_models.values()): ai_features.append("/separate")

      return {
          "message": "Welcome to the AI Audio Editor API.",
          "basic_features": features,
-         "ai_features": ai_features if ai_features else "None available (models might have failed to load)",
+         "ai_features": ai_features if ai_features else "None loaded (check logs)",
          "notes": "Requires FFmpeg. AI features require specific models loaded at startup (check logs)."
      }

- # /trim, /concat, /volume, /convert endpoints remain largely the same as before
- # Ensure they use the updated save_upload_file and cleanup logic
- # (Code for these endpoints omitted for brevity - refer to previous example)
- # ... Add /trim, /concat, /volume, /convert endpoints here ...
+ @app.post("/trim", tags=["Basic Editing"])
+ async def trim_audio(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to trim."),
+     start_ms: int = Form(..., description="Start time in milliseconds."),
+     end_ms: int = Form(..., description="End time in milliseconds.")
+ ):
+     """Trims an audio file to the specified start and end times (in milliseconds)."""
+     if start_ms < 0 or end_ms <= start_ms:
+         raise HTTPException(status_code=422, detail="Invalid start/end times. Ensure start_ms >= 0 and end_ms > start_ms.")
+
+     logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
+     input_path = None
+     output_path = None
+     try:
+         input_path = await save_upload_file(file, prefix="trim_in_")
+         background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
+
+         # Use Pydub for basic trim
+         audio = AudioSegment.from_file(input_path)
+         trimmed_audio = audio[start_ms:end_ms]
+         logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")
+
+         original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
+         if not original_format or original_format == "tmp": original_format = "mp3"
+         output_filename = f"trimmed_{uuid.uuid4().hex}.{original_format}"
+         output_path = os.path.join(TEMP_DIR, output_filename)
+
+         trimmed_audio.export(output_path, format=original_format)
+         background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{original_format}", # Attempt correct media type
+             filename=f"trimmed_{file.filename}"
+         )
+     except CouldntDecodeError:
+         logger.warning(f"pydub failed to decode: {file.filename}")
+         raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
+     except Exception as e:
+         logger.error(f"Error during trim operation: {e}", exc_info=True)
+         if output_path: cleanup_file(output_path) # Immediate cleanup on error
+         if input_path: cleanup_file(input_path) # Immediate cleanup on error
+         if isinstance(e, HTTPException): raise e
+         else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trimming: {str(e)}")
+
+
+ @app.post("/concat", tags=["Basic Editing"])
+ async def concatenate_audio(
+     background_tasks: BackgroundTasks,
+     files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
+     output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
+ ):
+     """Concatenates two or more audio files sequentially."""
+     if len(files) < 2:
+         raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")
+
+     logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
+     input_paths = []
+     loaded_audios = []
+     output_path = None
+
+     try:
+         combined_audio = AudioSegment.empty()
+         first_filename_base = "combined"

+         for i, file in enumerate(files):
+             input_path = await save_upload_file(file, prefix=f"concat_{i}_")
+             input_paths.append(input_path)
+             background_tasks.add_task(cleanup_file, input_path)
+             audio = AudioSegment.from_file(input_path)
+             combined_audio += audio
+             if i == 0: first_filename_base = os.path.splitext(file.filename)[0]
+             logger.info(f"Appended '{file.filename}', current total duration: {len(combined_audio)}ms")

- # --- AI Endpoints ---
+         if len(combined_audio) == 0:
+             raise HTTPException(status_code=500, detail="No audio data after attempting concatenation.")

+         output_filename_final = f"concat_{first_filename_base}_and_{len(files)-1}_others.{output_format}"
+         output_path = os.path.join(TEMP_DIR, f"concat_out_{uuid.uuid4().hex}.{output_format}")
+
+         combined_audio.export(output_path, format=output_format)
+         background_tasks.add_task(cleanup_file, output_path) # Schedule output cleanup
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{output_format}",
+             filename=output_filename_final
+         )
+     except CouldntDecodeError as e:
+         logger.warning(f"pydub failed to decode one of the concat files: {e}")
+         raise HTTPException(status_code=415, detail=f"Unsupported format or corrupted file among inputs: {e}")
+     except Exception as e:
+         logger.error(f"Error during concat operation: {e}", exc_info=True)
+         if output_path: cleanup_file(output_path)
+         for p in input_paths: cleanup_file(p)
+         if isinstance(e, HTTPException): raise e
+         else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during concatenation: {str(e)}")
+
+ @app.post("/volume", tags=["Basic Editing"])
+ async def change_volume(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to adjust volume for."),
+     change_db: float = Form(..., description="Volume change in decibels (dB). Positive values increase volume, negative values decrease.")
+ ):
+     """Adjusts the volume of an audio file by a specified decibel amount."""
+     logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
+     input_path = None
+     output_path = None
+     try:
+         input_path = await save_upload_file(file, prefix="volume_in_")
+         background_tasks.add_task(cleanup_file, input_path)
+
+         audio = AudioSegment.from_file(input_path)
+         adjusted_audio = audio + change_db
+         logger.info(f"Volume adjusted by {change_db}dB.")
+
+         original_format = os.path.splitext(file.filename)[1][1:].lower() or "mp3"
+         if not original_format or original_format == "tmp": original_format = "mp3"
+         output_filename_final = f"volume_{change_db}dB_{file.filename}"
+         output_path = os.path.join(TEMP_DIR, f"volume_out_{uuid.uuid4().hex}.{original_format}")
+
+         adjusted_audio.export(output_path, format=original_format)
+         background_tasks.add_task(cleanup_file, output_path)
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{original_format}",
+             filename=output_filename_final
+         )
+     except CouldntDecodeError:
+         logger.warning(f"pydub failed to decode: {file.filename}")
+         raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
+     except Exception as e:
+         logger.error(f"Error during volume operation: {e}", exc_info=True)
+         if output_path: cleanup_file(output_path)
+         if input_path: cleanup_file(input_path)
+         if isinstance(e, HTTPException): raise e
+         else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during volume adjustment: {str(e)}")
+
+ @app.post("/convert", tags=["Basic Editing"])
+ async def convert_format(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(..., description="Audio file to convert."),
+     output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
+ ):
+     """Converts an audio file to a different format."""
+     allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a'}
+     if output_format.lower() not in allowed_formats:
+         raise HTTPException(status_code=422, detail=f"Invalid output format. Allowed: {', '.join(allowed_formats)}")
+
+     logger.info(f"Convert request: file='{file.filename}', output_format='{output_format}'")
+     input_path = None
+     output_path = None
+     try:
+         input_path = await save_upload_file(file, prefix="convert_in_")
+         background_tasks.add_task(cleanup_file, input_path)
+
+         audio = AudioSegment.from_file(input_path)
+
+         output_format_lower = output_format.lower()
+         filename_base = os.path.splitext(file.filename)[0]
+         output_filename_final = f"{filename_base}_converted.{output_format_lower}"
+         output_path = os.path.join(TEMP_DIR, f"convert_out_{uuid.uuid4().hex}.{output_format_lower}")
+
+         audio.export(output_path, format=output_format_lower)
+         background_tasks.add_task(cleanup_file, output_path)
+
+         return FileResponse(
+             path=output_path,
+             media_type=f"audio/{output_format_lower}",
+             filename=output_filename_final
+         )
+     except CouldntDecodeError:
+         logger.warning(f"pydub failed to decode: {file.filename}")
+         raise HTTPException(status_code=415, detail="Unsupported audio format or corrupted file.")
+     except Exception as e:
+         logger.error(f"Error during convert operation: {e}", exc_info=True)
+         if output_path: cleanup_file(output_path)
+         if input_path: cleanup_file(input_path)
+         if isinstance(e, HTTPException): raise e
+         else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during format conversion: {str(e)}")
+
+
+ # (Include /enhance and /separate AI endpoints here - same as previous version)
+ # ...
  @app.post("/enhance", tags=["AI Editing"])
  async def enhance_speech(
      background_tasks: BackgroundTasks,
      file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
-     model_id: str = Query("speechbrain_sepformer", description="ID of the enhancement model to use (if multiple loaded)."),
+     model_key: str = Query("speechbrain_enhancer", description="Internal key of the enhancement model to use."),
      output_format: str = Form("wav", description="Output format for the enhanced audio (wav, flac recommended).")
  ):
      """Enhances speech audio using a pre-loaded AI model (experimental)."""
-     if torch is None or sf is None or np is None:
+     if not AI_LIBRARIES_AVAILABLE:
          raise HTTPException(status_code=501, detail="AI processing libraries not available.")
-     if model_id not in enhancement_pipelines:
-         raise HTTPException(status_code=503, detail=f"Enhancement model '{model_id}' is not loaded or available.")
+     if model_key not in enhancement_models or enhancement_models[model_key] is None:
+         logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
+         raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")

-     logger.info(f"Enhance request: file='{file.filename}', model='{model_id}', format='{output_format}'")
-     input_path = await save_upload_file(file, prefix="enhance_in_")
-     background_tasks.add_task(cleanup_file, input_path)
-
-     output_path = None # Define output_path before try block
+     logger.info(f"Enhance request: file='{file.filename}', model_key='{model_key}', format='{output_format}'")
+     input_path = None
+     output_path = None
      try:
+         input_path = await save_upload_file(file, prefix="enhance_in_")
+         background_tasks.add_task(cleanup_file, input_path)
+
          # Load audio, ensure correct SR for the model
+         logger.debug(f"Loading audio for enhancement, target SR: {ENHANCEMENT_SR}")
          audio_data, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
+         if current_sr != ENHANCEMENT_SR: # Should have been resampled, but double check
+             logger.warning(f"Audio SR after loading is {current_sr}, expected {ENHANCEMENT_SR}. Check resampling.")
+             # Depending on model strictness, could raise error or proceed cautiously.
+             # raise HTTPException(status_code=500, detail="Audio resampling failed.")

          # Run inference in a separate thread
          logger.info("Submitting enhancement task to background thread...")
-         model_pipeline = enhancement_pipelines[model_id] # Get the specific loaded pipeline/model
          enhanced_audio = await asyncio.to_thread(
-             _run_enhancement_sync, model_pipeline, audio_data, current_sr
+             _run_enhancement_sync, model_key, audio_data, current_sr # Pass key, data, and ACTUAL sr used
          )
          logger.info("Enhancement task completed.")

          # Save the result
-         output_path = save_hf_audio(enhanced_audio, current_sr, output_format) # Use current_sr (which is target_sr)
+         output_path = save_hf_audio(enhanced_audio, ENHANCEMENT_SR, output_format) # Save with model's target SR
          background_tasks.add_task(cleanup_file, output_path)

+         output_filename_final = f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
          return FileResponse(
              path=output_path,
              media_type=f"audio/{output_format}",
-             filename=f"enhanced_{file.filename}"
+             filename=output_filename_final
          )

      except Exception as e:
          logger.error(f"Error during enhancement operation: {e}", exc_info=True)
-         if output_path: cleanup_file(output_path) # Cleanup if error occurred after output started saving
+         if output_path: cleanup_file(output_path)
+         if input_path: cleanup_file(input_path)
          if isinstance(e, HTTPException): raise e
          else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")

@@ -350,105 +645,96 @@ async def enhance_speech(
  async def separate_sources(
      background_tasks: BackgroundTasks,
      file: UploadFile = File(..., description="Music audio file to separate into stems."),
-     model_id: str = Query("demucs", description="ID of the separation model to use."),
+     model_key: str = Query("demucs_separator", description="Internal key of the separation model to use."),
      stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
      output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
  ):
-     """Separates music into stems (vocals, drums, bass, other) using Demucs (experimental). Returns a ZIP archive."""
-     if torch is None or sf is None or np is None:
+     """Separates music into stems (vocals, drums, bass, other) using a pre-loaded AI model (experimental). Returns a ZIP archive."""
+     if not AI_LIBRARIES_AVAILABLE:
          raise HTTPException(status_code=501, detail="AI processing libraries not available.")
-     if model_id not in separation_models:
-         raise HTTPException(status_code=503, detail=f"Separation model '{model_id}' is not loaded or available.")
+     if model_key not in separation_models or separation_models[model_key] is None:
+         logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
+         raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")

-     valid_stems = {'vocals', 'drums', 'bass', 'other'}
+     valid_stems = {'vocals', 'drums', 'bass', 'other'} # Based on typical Demucs output
      requested_stems = set(s.lower() for s in stems)
      if not requested_stems.issubset(valid_stems):
-         raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are: {', '.join(valid_stems)}")
-
-     logger.info(f"Separate request: file='{file.filename}', model='{model_id}', stems={requested_stems}, format='{output_format}'")
-     input_path = await save_upload_file(file, prefix="separate_in_")
-     background_tasks.add_task(cleanup_file, input_path)
+         # Allow if all stems are requested even if validation set is smaller? Or just error.
+         raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Valid stems are generally: {', '.join(valid_stems)}")

+     logger.info(f"Separate request: file='{file.filename}', model_key='{model_key}', stems={requested_stems}, format='{output_format}'")
+     input_path = None
      stem_output_paths: Dict[str, str] = {}
-     zip_buffer = None
+     zip_buffer = io.BytesIO() # Use BytesIO for in-memory ZIP
+
      try:
-         # Load audio, ensure correct SR for the model (Demucs uses 44.1kHz)
+         input_path = await save_upload_file(file, prefix="separate_in_")
+         background_tasks.add_task(cleanup_file, input_path) # Schedule input cleanup
+
+         # Load audio, ensure correct SR for the model
+         logger.debug(f"Loading audio for separation, target SR: {DEMUCS_SR}")
          audio_data, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
+         if current_sr != DEMUCS_SR:
+             logger.warning(f"Audio SR after loading is {current_sr}, expected {DEMUCS_SR}. Check resampling.")
+             # raise HTTPException(status_code=500, detail="Audio resampling failed.")

          # Run inference in a separate thread
          logger.info("Submitting separation task to background thread...")
-         model = separation_models[model_id] # Get the specific loaded model
          all_separated_stems = await asyncio.to_thread(
-             _run_separation_sync, model, audio_data, current_sr
+             _run_separation_sync, model_key, audio_data, current_sr # Pass key, data, actual SR
          )
          logger.info("Separation task completed.")

          # --- Create ZIP file in memory ---
-         zip_buffer = io.BytesIO()
+         zip_filename_base = f"separated_{os.path.splitext(file.filename)[0]}"
          with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
-             # Save only the requested stems
+             logger.info(f"Creating ZIP archive in memory...")
+             found_stems_count = 0
              for stem_name in requested_stems:
                  if stem_name in all_separated_stems:
                      stem_data = all_separated_stems[stem_name]
+                     if stem_data is None or stem_data.size == 0:
+                         logger.warning(f"Stem '{stem_name}' data is empty, skipping.")
+                         continue
+
                      # Save stem temporarily to disk first (needed for pydub/sf.write)
-                     stem_path = save_hf_audio(stem_data, current_sr, output_format)
+                     logger.debug(f"Saving temporary stem file for '{stem_name}'...")
+                     stem_path = save_hf_audio(stem_data, DEMUCS_SR, output_format) # Save with model's target SR
                      stem_output_paths[stem_name] = stem_path # Keep track for cleanup
                      background_tasks.add_task(cleanup_file, stem_path) # Schedule cleanup

                      # Add the saved stem file to the ZIP archive
-                     archive_name = f"{stem_name}_{os.path.basename(input_path)}.{output_format}"
+                     archive_name = f"{stem_name}.{output_format}" # Simple name inside zip
                      zipf.write(stem_path, arcname=archive_name)
                      logger.info(f"Added '{archive_name}' to ZIP.")
+                     found_stems_count += 1
                  else:
-                     logger.warning(f"Requested stem '{stem_name}' not found in model output.")
+                     logger.warning(f"Requested stem '{stem_name}' not found in model output keys: {list(all_separated_stems.keys())}")
+
+             if found_stems_count == 0:
+                 raise HTTPException(status_code=404, detail="None of the requested stems were found or generated successfully.")

          zip_buffer.seek(0) # Rewind buffer pointer

-         # Return the ZIP file
-         zip_filename = f"separated_stems_{os.path.splitext(file.filename)[0]}.zip"
+         # Return the ZIP file via StreamingResponse
+         zip_filename_download = f"{zip_filename_base}.zip"
+         logger.info(f"Sending ZIP file '{zip_filename_download}'")
          return StreamingResponse(
-             zip_buffer,
+             zip_buffer, # Pass the BytesIO buffer directly
              media_type="application/zip",
-             headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
+             headers={'Content-Disposition': f'attachment; filename="{zip_filename_download}"'}
          )

      except Exception as e:
          logger.error(f"Error during separation operation: {e}", exc_info=True)
-         # Cleanup any stems that were saved before zipping failed
-         for path in stem_output_paths.values():
-             cleanup_file(path)
-         if zip_buffer: zip_buffer.close() # Close memory buffer
+         # Cleanup temporary stem files if they exist
+         for path in stem_output_paths.values(): cleanup_file(path)
+         # Close buffer just in case (though StreamingResponse should handle it)
+         # if zip_buffer and not zip_buffer.closed: zip_buffer.close()
+         if input_path: cleanup_file(input_path)

          if isinstance(e, HTTPException): raise e
          else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")


- # --- How to Run ---
- # 1. Ensure FFmpeg is installed and accessible.
- # 2. Save this code as `app.py`.
- # 3. Create `requirements.txt` (as shown above).
- # 4. Install dependencies: `pip install -r requirements.txt` (This can take time!)
- # 5. Run the FastAPI server: `uvicorn app:app --reload --host 0.0.0.0`
- #    (Use --host 0.0.0.0 for external/Docker access. --reload is optional)
- #
- # --- WARNING ---
- # - AI models require SIGNIFICANT RAM and CPU/GPU. Inference can be SLOW.
- # - The first run will download models, which can take a long time and lots of disk space.
- # - Ensure the specific model IDs used are correct and compatible with HF libraries.
- # - Model loading at startup might fail if dependencies are missing or resources are insufficient. Check logs!
- #
- # --- Example Usage (using curl) ---
- #
- # **Enhance:** (Enhance noisy_speech.wav)
- # curl -X POST "http://127.0.0.1:8000/enhance?model_id=speechbrain_sepformer" \
- #      -F "file=@noisy_speech.wav" \
- #      -F "output_format=wav" \
- #      --output enhanced_speech.wav
- #
- # **Separate:** (Separate vocals and drums from music.mp3)
- # curl -X POST "http://127.0.0.1:8000/separate?model_id=demucs" \
- #      -F "file=@music.mp3" \
- #      -F "stems=vocals" \
- #      -F "stems=drums" \
- #      -F "output_format=mp3" \
- #      --output separated_stems.zip
+ # ----------- END app.py -----------