Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -25,8 +25,11 @@ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(messag
|
|
25 |
ch = logging.StreamHandler()
|
26 |
ch.setLevel(logging.INFO)
|
27 |
ch.setFormatter(formatter)
|
28 |
-
|
|
|
|
|
29 |
|
|
|
30 |
try:
|
31 |
logger_init.info("Importing torch...")
|
32 |
import torch
|
@@ -45,189 +48,333 @@ try:
|
|
45 |
AI_LIBS_AVAILABLE = True
|
46 |
except ImportError as e:
|
47 |
logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
|
48 |
-
logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed.")
|
49 |
logger_init.error("AI features will be unavailable.")
|
|
|
50 |
torch = None
|
51 |
sf = None
|
52 |
np = None
|
53 |
librosa = None
|
54 |
speechbrain = None
|
55 |
demucs = None
|
56 |
-
AI_LIBS_AVAILABLE = False
|
57 |
|
58 |
# --- Configuration & Setup ---
|
59 |
TEMP_DIR = tempfile.gettempdir()
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
# Configure main app logging (use the root logger setup by FastAPI/Uvicorn)
|
63 |
-
|
|
|
64 |
|
65 |
# --- Global Variables for Loaded Models ---
|
66 |
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
|
67 |
-
|
|
|
68 |
|
69 |
enhancement_models: Dict[str, Any] = {}
|
70 |
separation_models: Dict[str, Any] = {}
|
71 |
|
72 |
-
|
73 |
-
|
|
|
74 |
|
75 |
# --- Device Selection ---
|
76 |
-
if torch:
|
77 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
78 |
-
logger_init.info(f"
|
79 |
else:
|
80 |
-
DEVICE = "cpu"
|
81 |
-
logger_init.info("Torch not available, defaulting device to CPU.")
|
82 |
|
83 |
|
84 |
-
# --- Helper Functions
|
|
|
85 |
def cleanup_file(file_path: str):
|
86 |
"""Safely remove a file."""
|
87 |
try:
|
88 |
-
if file_path and os.path.exists(file_path):
|
89 |
os.remove(file_path)
|
90 |
# logger.info(f"Cleaned up temporary file: {file_path}") # Reduce log noise
|
91 |
except Exception as e:
|
|
|
92 |
logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
|
93 |
|
94 |
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
|
95 |
"""Saves an uploaded file to a temporary location and returns the path."""
|
|
|
|
|
|
|
96 |
_, file_extension = os.path.splitext(upload_file.filename)
|
|
|
97 |
if not file_extension: file_extension = ".wav"
|
98 |
temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
|
|
|
99 |
try:
|
|
|
100 |
with open(temp_file_path, "wb") as buffer:
|
101 |
-
|
102 |
-
|
|
|
|
|
103 |
return temp_file_path
|
104 |
except Exception as e:
|
105 |
-
logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
|
106 |
-
cleanup_file(temp_file_path)
|
107 |
raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
|
108 |
finally:
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
-
# --- Audio Loading/Saving Functions (same as before) ---
|
112 |
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
|
113 |
-
"""Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
|
114 |
-
|
|
|
|
|
|
|
|
|
115 |
try:
|
116 |
audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
logger.
|
125 |
-
audio = np.mean(audio, axis=
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
|
|
|
|
|
|
|
|
127 |
audio_tensor = torch.from_numpy(audio).float()
|
|
|
|
|
|
|
128 |
if target_sr and orig_sr != target_sr:
|
129 |
-
if librosa is None: raise RuntimeError("Librosa missing")
|
130 |
logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
|
|
|
131 |
audio_np = audio_tensor.numpy()
|
132 |
-
resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
|
133 |
audio_tensor = torch.from_numpy(resampled_audio_np).float()
|
134 |
current_sr = target_sr
|
135 |
-
|
136 |
-
|
|
|
137 |
return audio_tensor.to(DEVICE), current_sr
|
|
|
|
|
|
|
|
|
|
|
138 |
except Exception as e:
|
139 |
-
logger.error(f"
|
140 |
cleanup_file(file_path)
|
141 |
-
raise HTTPException(status_code=
|
142 |
|
143 |
|
144 |
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
|
145 |
"""Saves audio data (Tensor or NumPy array) to a temporary file."""
|
146 |
-
|
|
|
|
|
147 |
output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format.lower()}"
|
148 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
149 |
try:
|
150 |
-
|
|
|
|
|
151 |
if isinstance(audio_data, torch.Tensor):
|
|
|
|
|
152 |
audio_np = audio_data.detach().cpu().numpy()
|
153 |
elif isinstance(audio_data, np.ndarray):
|
154 |
audio_np = audio_data
|
155 |
else:
|
156 |
-
raise TypeError(f"Unsupported audio data type: {type(audio_data)}")
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
-
|
159 |
audio_np = np.clip(audio_np, -1.0, 1.0)
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
if output_format.lower() in ['wav', 'flac']:
|
162 |
sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
|
163 |
else:
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
segment.export(output_path, format=output_format)
|
|
|
|
|
169 |
return output_path
|
170 |
except Exception as e:
|
171 |
logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
|
172 |
-
cleanup_file(output_path)
|
173 |
-
raise HTTPException(status_code=500, detail="Failed to save processed audio.")
|
174 |
|
175 |
-
|
|
|
176 |
def load_audio_pydub(file_path: str) -> AudioSegment:
|
177 |
-
|
|
|
|
|
178 |
try:
|
179 |
-
audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
return audio
|
181 |
-
except CouldntDecodeError
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
def export_audio_pydub(audio: AudioSegment, format: str) -> str:
|
185 |
-
|
186 |
output_filename = f"edited_{uuid.uuid4().hex}.{format.lower()}"
|
187 |
output_path = os.path.join(TEMP_DIR, output_filename)
|
188 |
try:
|
|
|
189 |
audio.export(output_path, format=format.lower())
|
190 |
return output_path
|
191 |
-
except Exception as e:
|
|
|
|
|
|
|
|
|
192 |
|
|
|
193 |
|
194 |
-
# --- Synchronous AI Inference Functions (same as before) ---
|
195 |
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
|
196 |
-
|
197 |
-
if not model: raise ValueError("Enhancement model not
|
198 |
try:
|
199 |
logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
|
200 |
-
model_device = next(model.parameters()).device
|
201 |
if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
|
|
|
202 |
if audio_tensor.ndim == 1: audio_tensor = audio_tensor.unsqueeze(0)
|
|
|
203 |
with torch.no_grad():
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
enhanced_audio = enhanced_tensor.squeeze(0).cpu()
|
206 |
logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
|
207 |
return enhanced_audio
|
208 |
-
except Exception as e:
|
|
|
|
|
209 |
|
210 |
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
|
211 |
-
|
212 |
-
if not model: raise ValueError("Separation model not
|
213 |
if not demucs: raise RuntimeError("Demucs library missing")
|
214 |
try:
|
215 |
logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
|
216 |
model_device = next(model.parameters()).device
|
217 |
if audio_tensor.device != model_device: audio_tensor = audio_tensor.to(model_device)
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
220 |
if audio_tensor.shape[1] != model.audio_channels:
|
221 |
-
if audio_tensor.shape[1] == 1:
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
with torch.no_grad():
|
|
|
|
|
224 |
audio_to_process = audio_tensor.squeeze(0)
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
stem_map = {name: out[i] for i, name in enumerate(model.sources)}
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
return output_stems
|
230 |
-
|
|
|
|
|
|
|
231 |
|
232 |
|
233 |
# --- Model Loading Function (Enhanced Logging) ---
|
@@ -235,29 +382,35 @@ def load_hf_models():
|
|
235 |
"""Loads AI models at startup using correct libraries."""
|
236 |
logger_load = logging.getLogger("ModelLoader") # Use specific logger
|
237 |
logger_load.setLevel(logging.INFO)
|
238 |
-
|
|
|
239 |
|
240 |
global enhancement_models, separation_models
|
241 |
if not AI_LIBS_AVAILABLE:
|
242 |
logger_load.error("Core AI libraries not available. Cannot load AI models.")
|
243 |
return
|
244 |
|
|
|
|
|
245 |
# --- Load Enhancement Model ---
|
246 |
enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
|
247 |
logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
|
248 |
try:
|
249 |
-
# Log device before loading
|
250 |
logger_load.info(f"Attempting load on device: {DEVICE}")
|
|
|
|
|
|
|
251 |
enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
|
252 |
source=enhancement_model_hparams,
|
|
|
253 |
run_opts={"device": DEVICE}
|
254 |
)
|
255 |
-
# Check model device after loading
|
256 |
model_device = next(enhancer.parameters()).device
|
257 |
enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
|
258 |
logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
|
|
|
259 |
except Exception as e:
|
260 |
-
logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
|
261 |
logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
|
262 |
logger_load.warning("Enhancement features will be unavailable.")
|
263 |
|
@@ -267,28 +420,29 @@ def load_hf_models():
|
|
267 |
logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
|
268 |
try:
|
269 |
logger_load.info(f"Attempting load on device: {DEVICE}")
|
270 |
-
# This automatically handles downloading the model checkpoint
|
271 |
separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
|
272 |
model_device = next(separator.parameters()).device
|
273 |
separation_models[SEPARATION_MODEL_KEY] = separator
|
274 |
logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
|
275 |
-
logger_load.info(f"Separation model sources: {separator.sources}")
|
|
|
276 |
except Exception as e:
|
277 |
logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
|
278 |
logger_load.error(f"Traceback: {traceback.format_exc()}")
|
279 |
-
logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")
|
280 |
logger_load.warning("Separation features will be unavailable.")
|
281 |
|
282 |
logger_load.info(f"--- Model loading attempts finished ---")
|
283 |
-
logger_load.info(f"
|
284 |
-
logger_load.info(f"
|
285 |
|
286 |
|
287 |
# --- FastAPI App ---
|
288 |
app = FastAPI(
|
289 |
title="AI Audio Editor API",
|
290 |
description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
|
291 |
-
version="2.1.
|
292 |
)
|
293 |
|
294 |
@app.on_event("startup")
|
@@ -296,210 +450,440 @@ async def startup_event():
|
|
296 |
# Use the init logger for startup messages
|
297 |
logger_init.info("--- FastAPI Application Startup ---")
|
298 |
if AI_LIBS_AVAILABLE:
|
299 |
-
logger_init.info("AI Libraries
|
300 |
# Run blocking model load in thread
|
301 |
await asyncio.to_thread(load_hf_models)
|
302 |
-
logger_init.info("Background model loading task finished (check ModelLoader logs for details).")
|
303 |
else:
|
304 |
-
logger_init.error("AI Libraries failed to import. AI features will be disabled.")
|
305 |
-
logger_init.info("--- Startup complete ---")
|
306 |
|
307 |
# --- API Endpoints ---
|
308 |
|
309 |
@app.get("/", tags=["General"])
|
310 |
def read_root():
|
311 |
-
"""Root endpoint providing a welcome message and
|
312 |
features = ["/trim", "/concat", "/volume", "/convert"]
|
313 |
-
|
314 |
-
|
315 |
-
if
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
return {
|
322 |
"message": "Welcome to the AI Audio Editor API.",
|
323 |
"status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
|
324 |
-
"
|
325 |
-
"
|
326 |
-
"
|
327 |
-
"ai_features": ai_features if ai_features else "None available (check startup logs)",
|
328 |
-
"notes": "Requires FFmpeg. AI features require models to load successfully at startup."
|
329 |
}
|
330 |
|
331 |
|
332 |
# --- Basic Editing Endpoints ---
|
333 |
-
|
334 |
@app.post("/trim", tags=["Basic Editing"])
|
335 |
-
async def trim_audio(
|
336 |
-
|
337 |
-
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
try:
|
340 |
-
audio = load_audio_pydub(input_path)
|
341 |
trimmed_audio = audio[start_ms:end_ms]
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
except Exception as e:
|
|
|
|
|
348 |
if output_path: cleanup_file(output_path)
|
349 |
-
|
|
|
350 |
|
351 |
@app.post("/concat", tags=["Basic Editing"])
|
352 |
-
async def concatenate_audio(
|
353 |
-
|
354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
try:
|
356 |
-
|
357 |
-
for file in files:
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
except Exception as e:
|
368 |
-
|
|
|
369 |
if output_path: cleanup_file(output_path)
|
370 |
-
|
|
|
371 |
|
372 |
@app.post("/volume", tags=["Basic Editing"])
|
373 |
-
async def change_volume(
|
374 |
-
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
try:
|
377 |
audio = load_audio_pydub(input_path)
|
378 |
-
|
379 |
-
|
380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
background_tasks.add_task(cleanup_file, output_path)
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
except Exception as e:
|
|
|
385 |
if output_path: cleanup_file(output_path)
|
386 |
-
|
|
|
387 |
|
388 |
@app.post("/convert", tags=["Basic Editing"])
|
389 |
-
async def convert_format(
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
try:
|
|
|
395 |
audio = load_audio_pydub(input_path)
|
396 |
-
|
|
|
|
|
|
|
397 |
background_tasks.add_task(cleanup_file, output_path)
|
398 |
-
|
399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
except Exception as e:
|
|
|
401 |
if output_path: cleanup_file(output_path)
|
402 |
-
|
403 |
|
404 |
|
405 |
-
# --- AI Endpoints
|
406 |
|
407 |
@app.post("/enhance", tags=["AI Editing"])
|
408 |
async def enhance_speech(
|
409 |
background_tasks: BackgroundTasks,
|
410 |
file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
|
411 |
-
model_key
|
|
|
412 |
output_format: str = Form("wav", description="Output format (wav, flac recommended).")
|
413 |
):
|
414 |
"""Enhances speech audio using a pre-loaded SpeechBrain model."""
|
415 |
-
if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
419 |
|
420 |
-
|
421 |
-
logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
|
422 |
input_path = await save_upload_file(file, prefix="enhance_in_")
|
423 |
background_tasks.add_task(cleanup_file, input_path)
|
424 |
output_path = None
|
425 |
try:
|
|
|
426 |
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)
|
|
|
427 |
logger.info("Submitting enhancement task to background thread...")
|
428 |
enhanced_audio_tensor = await asyncio.to_thread(
|
429 |
_run_enhancement_sync, loaded_model, audio_tensor, current_sr
|
430 |
)
|
431 |
logger.info("Enhancement task completed.")
|
|
|
|
|
432 |
output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
|
433 |
background_tasks.add_task(cleanup_file, output_path)
|
|
|
434 |
output_filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
except Exception as e:
|
437 |
-
logger.error(f"
|
438 |
if output_path: cleanup_file(output_path)
|
439 |
-
|
440 |
|
441 |
|
442 |
@app.post("/separate", tags=["AI Editing"])
|
443 |
async def separate_sources(
|
444 |
background_tasks: BackgroundTasks,
|
445 |
file: UploadFile = File(..., description="Music audio file to separate into stems."),
|
446 |
-
model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
|
447 |
stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
|
448 |
output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
|
449 |
):
|
450 |
"""Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
|
451 |
-
if not AI_LIBS_AVAILABLE: raise HTTPException(501,"AI libraries not available.")
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
455 |
|
456 |
-
loaded_model = separation_models[
|
457 |
valid_stems = set(loaded_model.sources)
|
458 |
requested_stems = set(s.lower() for s in stems)
|
459 |
-
if not requested_stems.issubset(valid_stems):
|
460 |
-
raise HTTPException(422, f"Invalid stem(s). Model '{model_key}' provides: {valid_stems}")
|
461 |
|
462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
input_path = await save_upload_file(file, prefix="separate_in_")
|
464 |
background_tasks.add_task(cleanup_file, input_path)
|
465 |
-
stem_output_paths: Dict[str, str] = {}
|
466 |
-
zip_buffer = io.BytesIO(); zipf = None #
|
467 |
|
468 |
try:
|
|
|
469 |
audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)
|
|
|
470 |
logger.info("Submitting separation task to background thread...")
|
471 |
all_separated_stems_tensors = await asyncio.to_thread(
|
472 |
_run_separation_sync, loaded_model, audio_tensor, current_sr
|
473 |
)
|
474 |
-
logger.info("Separation task completed.")
|
475 |
|
|
|
|
|
476 |
zipf = zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED)
|
|
|
477 |
for stem_name in requested_stems:
|
478 |
if stem_name in all_separated_stems_tensors:
|
479 |
stem_tensor = all_separated_stems_tensors[stem_name]
|
480 |
-
stem_path =
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
|
487 |
zipf.close() # Close zip file BEFORE seeking/reading
|
488 |
-
|
489 |
|
490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
return StreamingResponse(
|
492 |
-
zip_buffer,
|
493 |
media_type="application/zip",
|
494 |
headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
|
495 |
)
|
496 |
-
except
|
497 |
-
logger.error(f"
|
498 |
-
# Ensure buffer/zipfile are closed and temp files cleaned up on error
|
499 |
if zipf: zipf.close() # Ensure zipfile is closed
|
500 |
if zip_buffer: zip_buffer.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
for path in stem_output_paths.values(): cleanup_file(path)
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
ch = logging.StreamHandler()
|
26 |
ch.setLevel(logging.INFO)
|
27 |
ch.setFormatter(formatter)
|
28 |
+
# Avoid adding handler multiple times if script reloads
|
29 |
+
if not logger_init.handlers:
|
30 |
+
logger_init.addHandler(ch)
|
31 |
|
32 |
+
AI_LIBS_AVAILABLE = False
|
33 |
try:
|
34 |
logger_init.info("Importing torch...")
|
35 |
import torch
|
|
|
48 |
AI_LIBS_AVAILABLE = True
|
49 |
except ImportError as e:
|
50 |
logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
|
51 |
+
logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
|
52 |
logger_init.error("AI features will be unavailable.")
|
53 |
+
# Define placeholders so the rest of the code doesn't break completely on import error
|
54 |
torch = None
|
55 |
sf = None
|
56 |
np = None
|
57 |
librosa = None
|
58 |
speechbrain = None
|
59 |
demucs = None
|
|
|
60 |
|
61 |
# --- Configuration & Setup ---
# Scratch directory for uploads and processed output. The directory is created
# defensively; if that fails, the working directory is used instead so the
# application can still start.
TEMP_DIR = tempfile.gettempdir()
try:
    os.makedirs(TEMP_DIR, exist_ok=True)
except OSError as e:
    logger_init.error(f"Could not create temporary directory {TEMP_DIR}: {e}")
    TEMP_DIR = "."
    logger_init.warning(f"Using current directory '{TEMP_DIR}' for temporary files.")

# Module-level logger used by the helper functions and endpoint handlers.
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Keys under which the loaded models are stored in the registries below.
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
SEPARATION_MODEL_KEY = "htdemucs"  # "mdx_extra_q" is noted as a faster quantized option

# Registries filled in by the model loader; empty until loading succeeds.
enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

# Sample rates the audio is resampled to before inference.
# NOTE(review): values follow the original author's notes (Sepformer WHAMR at
# 16 kHz, Demucs at 44.1 kHz) - confirm against the model cards.
ENHANCEMENT_SR = 16000
DEMUCS_SR = 44100

# --- Device Selection ---
# CUDA is only probed when the AI imports succeeded; otherwise fall back to CPU.
if not (AI_LIBS_AVAILABLE and torch):
    DEVICE = "cpu"
    logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")
else:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    logger_init.info(f"Selected device for AI models: {DEVICE}")
|
96 |
|
97 |
|
98 |
+
# --- Helper Functions ---
|
99 |
+
|
100 |
def cleanup_file(file_path: str):
    """Remove a temporary file without ever raising.

    Intended for use as a best-effort cleanup step (including as a background
    task), so every failure is logged rather than propagated.

    Args:
        file_path: Path of the file to delete. Empty values and non-string
            values are ignored.
    """
    if not file_path or not isinstance(file_path, str):
        return
    try:
        # Attempt the removal directly instead of checking for existence
        # first: a separate exists() check can race with another cleanup of
        # the same path and turn a harmless situation into a logged error.
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone - which is exactly the state we wanted.
        pass
    except (OSError, ValueError) as e:
        # OSError: permissions, path is a directory, etc.
        # ValueError: malformed path (e.g. embedded NUL byte).
        # Log but don't crash, so cleanup of other files can continue.
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)
|
109 |
|
110 |
async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Stream an uploaded file to a uniquely named temporary file.

    Args:
        upload_file: The incoming upload; it is always closed before returning.
        prefix: Prefix for the generated temporary file name.

    Returns:
        Path of the saved file inside ``TEMP_DIR``.

    Raises:
        HTTPException: 400 if the upload has no filename, 500 if the data
            could not be written (any partial file is removed first).
    """
    if not upload_file or not upload_file.filename:
        raise HTTPException(status_code=400, detail="Invalid file upload object.")

    # The filename comes from the client, so only a short ASCII-alphanumeric
    # extension is carried over into the on-disk name. Anything else (missing,
    # over-long, or containing odd characters) falls back to ".wav".
    _, file_extension = os.path.splitext(upload_file.filename)
    suffix = file_extension[1:]
    if not (suffix and suffix.isascii() and suffix.isalnum() and len(suffix) <= 10):
        file_extension = ".wav"
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")

    try:
        logger.debug(f"Attempting to save uploaded file to: {temp_file_path}")
        with open(temp_file_path, "wb") as buffer:
            # Copy in 1 MB chunks so large uploads are never held in memory whole.
            while content := await upload_file.read(1024 * 1024):
                buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {temp_file_path}: {e}", exc_info=True)
        cleanup_file(temp_file_path)  # Remove any partially written file
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}") from e
    finally:
        # Always release the upload; a failure to close must not mask the
        # outcome of the save itself, so it is only recorded at debug level.
        try:
            await upload_file.close()
        except Exception as close_err:
            logger.debug(f"Ignoring error while closing upload '{upload_file.filename}': {close_err}")
|
138 |
+
|
139 |
+
|
140 |
+
# --- Audio Loading/Saving Functions ---
|
141 |
|
|
|
142 |
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[torch.Tensor, int]":
    """Load an audio file as a mono float32 tensor placed on ``DEVICE``.

    The return annotation is written as a string on purpose: when the AI
    imports fail the module sets ``torch = None``, and an eagerly evaluated
    ``torch.Tensor`` annotation would raise at definition time and defeat
    that fallback.

    Args:
        file_path: Path of the audio file to read with soundfile.
        target_sr: Sample rate to resample to; ``None`` (or 0) keeps the
            file's own rate.

    Returns:
        ``(audio_tensor, sample_rate)`` where ``audio_tensor`` is 1-D and
        ``sample_rate`` is the rate of the returned samples.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 500 if the
            file is missing or processing fails, 415 if soundfile cannot
            decode it. On decode/processing failure the input file is removed.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")

    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")

        # Downmix to mono. soundfile returns multi-channel data as
        # (frames, channels), so the channel axis is always axis 1; guessing
        # it from the smaller dimension breaks on clips that have fewer
        # frames than channels.
        if audio.ndim > 1:
            logger.info(f"Detected {audio.shape[1]} channels. Converting to mono by averaging axis 1.")
            audio = audio.mean(axis=1)
            logger.debug(f"Shape after mono conversion: {audio.shape}")

        # Guarantee a contiguous 1-D array for the steps below.
        audio = audio.flatten()

        # Resample on the NumPy array (librosa works on NumPy) before the
        # single conversion to a tensor.
        current_sr = orig_sr
        if target_sr and orig_sr != target_sr:
            if librosa is None:
                raise RuntimeError("Librosa missing for resampling")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best')
            current_sr = target_sr
            logger.info(f"Resampled audio tensor shape: {audio.shape}")

        audio_tensor = torch.from_numpy(audio).float()

        # Move to the device selected at startup.
        return audio_tensor.to(DEVICE), current_sr

    except sf.SoundFileError as sf_err:
        logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}") from sf_err
    except Exception as e:
        logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.") from e
|
196 |
|
197 |
|
198 |
def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Write model output to a temporary mono audio file.

    Args:
        audio_data: A torch tensor or NumPy array of samples; values are
            clipped to [-1, 1] and multi-dimensional input is reduced to mono.
        sampling_rate: Sample rate of ``audio_data`` in Hz.
        output_format: Target format name; matched case-insensitively. "wav"
            and "flac" are written with soundfile, anything else goes
            through pydub.

    Returns:
        Path of the written file inside ``TEMP_DIR``.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 400 if
            ``output_format`` is not a plain alphanumeric name, 500 if
            conversion or writing fails (any partial file is removed).
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")

    # Normalise once and use the same value for the file name and for the
    # encoder. The format comes from the request and becomes part of a file
    # path, so only plain ASCII alphanumeric names are accepted.
    fmt = output_format.strip().lower()
    if not (fmt.isascii() and fmt.isalnum()):
        raise HTTPException(status_code=400, detail=f"Invalid output format: '{output_format}'.")

    output_filename = f"ai_output_{uuid.uuid4().hex}.{fmt}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")

        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to NumPy array.")
            # The tensor must be detached and on the CPU before .numpy().
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")

        if audio_np.dtype != np.float32:
            logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
            audio_np = audio_np.astype(np.float32)

        # Clip so the int16 scaling below and [-1, 1] formats cannot overflow.
        audio_np = np.clip(audio_np, -1.0, 1.0)

        # Drop singleton axes (e.g. a batch dimension of 1) first so they do
        # not get mistaken for the channel axis.
        audio_np = np.atleast_1d(np.squeeze(audio_np))
        if audio_np.ndim > 1:
            logger.warning(f"Output audio data has {audio_np.ndim} dimensions, attempting to flatten or take first dimension.")
            # Heuristic: the smallest axis is treated as channels and
            # averaged when it looks like a plausible channel count (<10).
            channel_dim = int(np.argmin(audio_np.shape))
            if 1 < audio_np.shape[channel_dim] < 10:
                audio_np = np.mean(audio_np, axis=channel_dim)
            if audio_np.ndim > 1:
                # Still not 1-D (ambiguous shape or >2 dimensions): flatten.
                audio_np = audio_np.flatten()

        if fmt in ('wav', 'flac'):
            sf.write(output_path, audio_np, sampling_rate, format=fmt.upper())
        else:
            logger.debug(f"Using pydub to export to lossy format: {output_format}")
            # pydub works on integer PCM: scale float [-1, 1] to int16.
            audio_int16 = (audio_np * 32767).astype(np.int16)
            segment = AudioSegment(
                audio_int16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=audio_int16.dtype.itemsize,
                channels=1,  # Mono is guaranteed by the reduction above
            )
            segment.export(output_path, format=fmt)

        logger.info(f"Successfully saved AI audio to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)  # Remove any partially written file
        raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.") from e
|
261 |
|
262 |
+
|
263 |
+
# --- Pydub Loading/Exporting (for basic edits) ---
|
264 |
def load_audio_pydub(file_path: str) -> AudioSegment:
    """Load an audio file into a pydub ``AudioSegment``.

    The file extension is passed to pydub as a format hint. Because the
    extension is only a hint (uploads without one are stored as ".wav"
    regardless of content), a failed decode is retried once with pydub's
    own format detection before giving up.

    Args:
        file_path: Path of the audio file to load.

    Returns:
        The decoded audio.

    Raises:
        HTTPException: 500 if the file is missing or loading fails
            unexpectedly, 415 if the audio cannot be decoded. On failure the
            input file is removed.
    """
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")
    try:
        logger.debug(f"Loading audio with pydub: {file_path}")
        file_ext = os.path.splitext(file_path)[1][1:].lower()
        if file_ext:
            try:
                audio = AudioSegment.from_file(file_path, format=file_ext)
            except CouldntDecodeError:
                # The extension may mislabel the content or not be a format
                # name the decoder knows; let pydub detect the format itself.
                logger.info(f"Decoding '{os.path.basename(file_path)}' as '{file_ext}' failed; retrying with format auto-detection.")
                audio = AudioSegment.from_file(file_path)
        else:
            audio = AudioSegment.from_file(file_path)  # Let pydub detect
        logger.info(f"Loaded audio using pydub from: {file_path}")
        return audio
    except CouldntDecodeError as e:
        logger.warning(f"Pydub CouldntDecodeError for {file_path}: {e}")
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {os.path.basename(file_path)}") from e
    except Exception as e:
        logger.error(f"Error loading audio file {file_path} with pydub: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {os.path.basename(file_path)}") from e
|
286 |
|
287 |
def export_audio_pydub(audio: AudioSegment, format: str) -> str:
    """Export a pydub ``AudioSegment`` to a temporary file and return its path.

    Args:
        audio: The segment to write.
        format: Target container/codec name (e.g. ``"mp3"``, ``"wav"``). It is
            lower-cased before use. Because callers forward this value from
            request data and it becomes part of a filesystem path, only ASCII
            letters, digits and underscores are accepted.

    Returns:
        Path of the exported file inside ``TEMP_DIR``.

    Raises:
        HTTPException: 422 when *format* is not a plain format token, 500 when
            pydub fails to export (the partial output file is removed first).
    """
    export_format = format.lower() if isinstance(format, str) else ""
    is_plain_token = (
        bool(export_format)
        and export_format.isascii()
        and all(ch.isalnum() or ch == "_" for ch in export_format)
    )
    if not is_plain_token:
        # Reject path separators, dots, whitespace etc. before they reach the
        # temp-file name or the encoder command line.
        raise HTTPException(status_code=422, detail=f"Invalid output format '{format}'.")

    output_filename = f"edited_{uuid.uuid4().hex}.{export_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
        audio.export(output_path, format=export_format)
        return output_path
    except Exception as e:
        logger.error(f"Error exporting audio with pydub to format {format}: {e}", exc_info=True)
        cleanup_file(output_path)  # Cleanup if export failed
        raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
|
299 |
+
|
300 |
|
301 |
+
# --- Synchronous AI Inference Functions ---
|
302 |
|
|
|
303 |
def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
    """Run the enhancement model on one waveform (blocking call).

    Intended to be executed in a worker thread by the async endpoint.

    Args:
        model: Loaded enhancement model. Its ``enhance_batch`` method is used
            when present, otherwise ``forward``.
        audio_tensor: Input waveform; a 1-D tensor gets a leading batch axis.
        sampling_rate: Sample rate of the input, used for logging only.

    Returns:
        The enhanced waveform on the CPU with the leading batch axis squeezed.

    Raises:
        ValueError: When the AI libraries or the model are unavailable.
        RuntimeError: When the model exposes no usable inference method.
        Exception: Any inference error is logged and re-raised unchanged.
    """
    if not AI_LIBS_AVAILABLE or not model:
        raise ValueError("Enhancement model/libs not available")

    import inspect  # local import: only needed for signature inspection here

    try:
        logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device  # Check model's current device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)
        # Add batch dimension if model expects it (most do)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        enhance_method = getattr(model, "enhance_batch", getattr(model, "forward", None))
        if not callable(enhance_method):
            raise RuntimeError("Enhancement model exposes neither 'enhance_batch' nor 'forward'")

        # Inspect the real parameter list; co_varnames would also report
        # local variables and is missing on non-function callables.
        try:
            accepts_lengths = "lengths" in inspect.signature(enhance_method).parameters
        except (TypeError, ValueError):
            accepts_lengths = False

        with torch.no_grad():
            if accepts_lengths:
                # SpeechBrain lengths are relative to the longest item in the
                # batch (values in (0, 1]), not absolute sample counts. Every
                # row here is unpadded, so each relative length is 1.0.
                relative_lengths = torch.ones(audio_tensor.shape[0], device=model_device)
                enhanced_tensor = enhance_method(audio_tensor, lengths=relative_lengths)
            else:
                enhanced_tensor = enhance_method(audio_tensor)

        # Remove batch dimension from output before returning, move back to CPU
        enhanced_audio = enhanced_tensor.squeeze(0).cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise  # Re-raise to be caught by the async wrapper
|
329 |
|
330 |
def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
    """Run the Demucs separation model on one mixture (blocking call).

    Intended to be executed in a worker thread by the async endpoint.

    Args:
        model: Loaded Demucs model; ``model.audio_channels`` and
            ``model.sources`` are read.
        audio_tensor: Input mixture. A 1-D tensor is treated as mono samples;
            a 2-D tensor is treated as (batch, samples) as in the original
            code (NOTE(review): confirm the loader never returns
            (channels, samples)); a 3-D tensor is used as
            (batch, channels, samples).
        sampling_rate: Sample rate of the input, used for logging only.

    Returns:
        Mapping of stem name to a mono (channel-averaged) CPU tensor for the
        first batch item.

    Raises:
        ValueError: When libs/model are unavailable or the channel count does
            not match the model and the input is not mono.
        RuntimeError: When the demucs package is missing.
        Exception: Any inference error is logged and re-raised unchanged.
    """
    if not AI_LIBS_AVAILABLE or not model:
        raise ValueError("Separation model/libs not available")
    if not demucs:
        raise RuntimeError("Demucs library missing")
    try:
        logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)

        # Demucs expects audio as (batch, channels, samples)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, N)
        elif audio_tensor.ndim == 2:
            audio_tensor = audio_tensor.unsqueeze(1)  # (B, 1, N)

        # Repeat channel if model expects stereo but input is mono
        if audio_tensor.shape[1] != model.audio_channels:
            if audio_tensor.shape[1] == 1:
                logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
                audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
            else:
                raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")

        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")

        with torch.no_grad():
            # apply_model handles chunking/overlap and requires a batched
            # (batch, channels, samples) mixture, so the batch axis is kept
            # (slicing, not squeezing) and only the first item is processed.
            first_mix = audio_tensor[:1]
            # Note: shifts=1, split=True are common defaults for quality
            batched_out = demucs.apply.apply_model(
                model,
                first_mix,
                device=model_device,
                shifts=1,
                split=True,
                overlap=0.25,
                progress=False,  # Disable progress bar in logs
            )
            # Result is (batch, stems, channels, samples); drop the batch axis.
            out = batched_out[0]

        logger.debug(f"Raw separated sources tensor shape: {out.shape}")

        # Map stems by the model's source order, average channels down to
        # mono for simplicity, and move each result to the CPU.
        output_stems = {
            name: out[i].mean(dim=0).detach().cpu()
            for i, name in enumerate(model.sources)
        }

        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems

    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise
|
378 |
|
379 |
|
380 |
# --- Model Loading Function (Enhanced Logging) ---
|
|
|
382 |
"""Loads AI models at startup using correct libraries."""
|
383 |
logger_load = logging.getLogger("ModelLoader") # Use specific logger
|
384 |
logger_load.setLevel(logging.INFO)
|
385 |
+
# Ensure handler is attached if logger is newly created
|
386 |
+
if not logger_load.handlers and ch: logger_load.addHandler(ch)
|
387 |
|
388 |
global enhancement_models, separation_models
|
389 |
if not AI_LIBS_AVAILABLE:
|
390 |
logger_load.error("Core AI libraries not available. Cannot load AI models.")
|
391 |
return
|
392 |
|
393 |
+
load_success_flags = {"enhancement": False, "separation": False}
|
394 |
+
|
395 |
# --- Load Enhancement Model ---
|
396 |
enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
|
397 |
logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
|
398 |
try:
|
|
|
399 |
logger_load.info(f"Attempting load on device: {DEVICE}")
|
400 |
+
# Consider adding savedir if cache issues arise in HF Spaces
|
401 |
+
# savedir_sb = os.path.join(TEMP_DIR, "speechbrain_models")
|
402 |
+
# os.makedirs(savedir_sb, exist_ok=True)
|
403 |
enhancer = speechbrain.pretrained.SepformerEnhancement.from_hparams(
|
404 |
source=enhancement_model_hparams,
|
405 |
+
# savedir=savedir_sb,
|
406 |
run_opts={"device": DEVICE}
|
407 |
)
|
|
|
408 |
model_device = next(enhancer.parameters()).device
|
409 |
enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
|
410 |
logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
|
411 |
+
load_success_flags["enhancement"] = True
|
412 |
except Exception as e:
|
413 |
+
logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
|
414 |
logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
|
415 |
logger_load.warning("Enhancement features will be unavailable.")
|
416 |
|
|
|
420 |
logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
|
421 |
try:
|
422 |
logger_load.info(f"Attempting load on device: {DEVICE}")
|
423 |
+
# This automatically handles downloading the model checkpoint via demucs package
|
424 |
separator = demucs.apply.load_model(name=separation_model_name, device=DEVICE)
|
425 |
model_device = next(separator.parameters()).device
|
426 |
separation_models[SEPARATION_MODEL_KEY] = separator
|
427 |
logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
|
428 |
+
logger_load.info(f"Separation model available sources: {separator.sources}")
|
429 |
+
load_success_flags["separation"] = True
|
430 |
except Exception as e:
|
431 |
logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
|
432 |
logger_load.error(f"Traceback: {traceback.format_exc()}")
|
433 |
+
logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
|
434 |
logger_load.warning("Separation features will be unavailable.")
|
435 |
|
436 |
logger_load.info(f"--- Model loading attempts finished ---")
|
437 |
+
logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
|
438 |
+
logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")
|
439 |
|
440 |
|
441 |
# --- FastAPI App ---
|
442 |
# ASGI application object; every route and the startup hook below register on it.
app = FastAPI(
    version="2.1.2",  # Incremented version
    title="AI Audio Editor API",
    description=(
        "API for basic audio editing and AI-powered enhancement & separation. "
        "Requires FFmpeg and specific AI libraries."
    ),
)
|
447 |
|
448 |
@app.on_event("startup")
|
|
|
450 |
# Use the init logger for startup messages
|
451 |
logger_init.info("--- FastAPI Application Startup ---")
|
452 |
if AI_LIBS_AVAILABLE:
|
453 |
+
logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
|
454 |
# Run blocking model load in thread
|
455 |
await asyncio.to_thread(load_hf_models)
|
456 |
+
logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
|
457 |
else:
|
458 |
+
logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
|
459 |
+
logger_init.info("--- Startup sequence complete ---")
|
460 |
|
461 |
# --- API Endpoints ---
|
462 |
|
463 |
@app.get("/", tags=["General"])
def read_root():
    """Root endpoint providing a welcome message and status of loaded models."""
    # The report has one entry per model key when the AI libraries imported,
    # or a single "AI Status" entry when they did not.
    not_loaded = "Failed to load (check startup logs)"

    if not AI_LIBS_AVAILABLE:
        model_report = {"AI Status": "Libraries Failed Import"}
    else:
        model_report = {
            ENHANCEMENT_MODEL_KEY: "Loaded" if enhancement_models else not_loaded
        }
        if separation_models:
            separator = separation_models.get(SEPARATION_MODEL_KEY)
            source_names = ', '.join(separator.sources) if separator else 'N/A'
            model_report[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {source_names})"
        else:
            model_report[SEPARATION_MODEL_KEY] = not_loaded

    library_state = "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed"
    return {
        "message": "Welcome to the AI Audio Editor API.",
        "status": library_state,
        "ai_models_status": model_report,
        "basic_endpoints": ["/trim", "/concat", "/volume", "/convert"],
        "notes": "Requires FFmpeg. AI features require successful model loading at startup."
    }
|
492 |
|
493 |
|
494 |
# --- Basic Editing Endpoints ---
|
495 |
+
|
496 |
@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
    end_ms: int = Form(..., gt=0, description="End time in milliseconds.")  # Ensure end > 0
):
    """Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub."""
    # Flow: validate range -> save upload -> load with pydub -> slice ->
    # export in the upload's own format (mp3 fallback) -> stream the file.
    # Raises HTTPException 422 for an empty/negative range, and whatever
    # load/export raise; any other failure becomes a 500.
    if end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")

    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    input_path = await save_upload_file(file, prefix="trim_in_")
    # Background tasks are attached to the response that gets returned, so
    # they only cover the success path; error paths clean up explicitly below.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None  # Define before try block

    # An upload may arrive without a filename; fall back so the path helpers
    # below never receive None.
    source_name = file.filename or "audio"

    try:
        audio = load_audio_pydub(input_path)  # Can raise HTTPException
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")

        # Determine original format for export
        original_format = os.path.splitext(source_name)[1][1:].lower()
        # Use mp3 as default only if no extension or if it's implausibly long
        if not original_format or len(original_format) > 5:
            original_format = "mp3"
            logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")

        output_path = export_audio_pydub(trimmed_audio, original_format)  # Can raise HTTPException
        background_tasks.add_task(cleanup_file, output_path)  # Schedule output cleanup

        # Create a more informative filename
        output_filename = f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(source_name)[0]}.{original_format}"

        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",  # Best guess for media type
            filename=output_filename
        )
    except HTTPException as http_exc:
        # Load/export failures are passed through unchanged.
        logger.error(f"HTTP Exception during trim: {http_exc.detail}")
        cleanup_file(input_path)  # scheduled cleanup will not run on an error response
        if output_path:
            cleanup_file(output_path)
        raise http_exc
    except Exception as e:
        # Catch other unexpected errors during trimming logic
        logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}")
|
547 |
+
|
548 |
|
549 |
@app.post("/concat", tags=["Basic Editing"])
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
):
    """Concatenates two or more audio files sequentially using Pydub."""
    # Uploads that are empty or fail to load are skipped with a log entry;
    # the request only fails (400) when nothing at all could be loaded.
    # Raises HTTPException 422 for fewer than two uploads, and whatever
    # export raises; any other failure becomes a 500.
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")

    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
    input_paths = []  # Keep track of all saved input file paths
    output_path = None  # Define before try block

    try:
        combined_audio: Optional[AudioSegment] = None
        # Names of the uploads that actually made it into the result, so the
        # download name reflects what was combined rather than what was sent.
        combined_names: List[str] = []
        for i, file in enumerate(files):
            if not file or not file.filename:
                logger.warning(f"Skipping invalid file upload at index {i}.")
                continue  # Skip potentially empty file entries

            input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
            input_paths.append(input_path)
            # Covers the success path only; error paths clean up explicitly.
            background_tasks.add_task(cleanup_file, input_path)

            try:
                audio = load_audio_pydub(input_path)
                if combined_audio is None:
                    combined_audio = audio
                    logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
                else:
                    logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
                    combined_audio += audio
                combined_names.append(file.filename)
            except HTTPException as load_exc:
                # Log error but continue trying to load other files if possible
                logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
            except Exception as load_exc:
                logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)

        if combined_audio is None:
            raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")

        logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")

        output_path = export_audio_pydub(combined_audio, output_format)  # Can raise HTTPException
        background_tasks.add_task(cleanup_file, output_path)  # Schedule output cleanup

        # Name the download after the first file that was really combined and
        # count only the other files that were combined with it.
        first_filename_base = os.path.splitext(combined_names[0])[0]
        output_filename = f"concat_{first_filename_base}_and_{len(combined_names)-1}_others.{output_format}"

        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during concat: {http_exc.detail}")
        # Scheduled cleanups do not run on an error response, so remove the
        # saved inputs and any output here.
        for saved_path in input_paths:
            cleanup_file(saved_path)
        if output_path:
            cleanup_file(output_path)
        raise http_exc
    except Exception as e:
        # Catch other unexpected errors during combining logic
        logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
        for saved_path in input_paths:
            cleanup_file(saved_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}")
|
619 |
+
|
620 |
|
621 |
@app.post("/volume", tags=["Basic Editing"])
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
):
    """Adjusts the volume of an audio file by a specified decibel amount using Pydub."""
    # Flow: save upload -> load with pydub -> apply gain -> export in the
    # upload's own format (mp3 fallback) -> stream the file.
    # Load/export HTTPExceptions pass through; anything else becomes a 500.
    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    input_path = await save_upload_file(file, prefix="volume_in_")
    # Background tasks are attached to the response that gets returned, so
    # they only cover the success path; error paths clean up explicitly below.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None

    # An upload may arrive without a filename; fall back so the path helpers
    # below never receive None.
    source_name = file.filename or "audio"

    try:
        audio = load_audio_pydub(input_path)
        # Check for potential silence before applying gain
        if audio.dBFS == -float('inf'):
            logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
        adjusted_audio = audio + change_db
        logger.info(f"Volume adjusted by {change_db}dB.")

        original_format = os.path.splitext(source_name)[1][1:].lower()
        if not original_format or len(original_format) > 5:
            original_format = "mp3"

        output_path = export_audio_pydub(adjusted_audio, original_format)
        background_tasks.add_task(cleanup_file, output_path)

        # Create filename; negative values already carry their own sign.
        sign = "+" if change_db >= 0 else ""
        output_filename = f"volume_{sign}{change_db}dB_{os.path.splitext(source_name)[0]}.{original_format}"

        return FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",
            filename=output_filename
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
        cleanup_file(input_path)  # scheduled cleanup will not run on an error response
        if output_path:
            cleanup_file(output_path)
        raise http_exc
    except Exception as e:
        logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}")
|
663 |
+
|
664 |
|
665 |
@app.post("/convert", tags=["Basic Editing"])
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
):
    """Converts an audio file to a different format using Pydub."""
    # Flow: validate target format -> save upload -> load with pydub ->
    # export in the target format -> stream the file with a matching MIME type.
    # Raises HTTPException 422 for a format outside the allow-list;
    # load/export HTTPExceptions pass through; anything else becomes a 500.

    # Define allowed formats explicitly
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'}  # Expanded list
    output_format_lower = output_format.lower()
    if output_format_lower not in allowed_formats:
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(allowed_formats))}")

    logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
    input_path = await save_upload_file(file, prefix="convert_in_")
    # Background tasks are attached to the response that gets returned, so
    # they only cover the success path; error paths clean up explicitly below.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None

    try:
        # Load using pydub, which handles many input formats
        audio = load_audio_pydub(input_path)
        logger.info(f"Successfully loaded '{file.filename}' for conversion.")

        # Export using pydub
        output_path = export_audio_pydub(audio, output_format_lower)
        background_tasks.add_task(cleanup_file, output_path)
        logger.info(f"Successfully exported to {output_format_lower}")

        # Construct new filename; an upload without a filename falls back to
        # a generic stem instead of crashing os.path.splitext.
        filename_base = os.path.splitext(file.filename or "audio")[0]
        output_filename = f"{filename_base}_converted.{output_format_lower}"

        # Determine media type (MIME type) - might need refinement for less common types
        media_type_map = {
            'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
            'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4',  # m4a often uses mp4 container
            'opus': 'audio/opus', 'wma': 'audio/x-ms-wma', 'aiff': 'audio/aiff'
        }
        media_type = media_type_map.get(output_format_lower, 'application/octet-stream')  # Default binary if unknown

        return FileResponse(
            path=output_path,
            media_type=media_type,
            filename=output_filename
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
        cleanup_file(input_path)  # scheduled cleanup will not run on an error response
        if output_path:
            cleanup_file(output_path)
        raise http_exc
    except Exception as e:
        logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}")
|
717 |
|
718 |
|
719 |
+
# --- AI Endpoints ---
|
720 |
|
721 |
@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    # Keep model_key optional for now, assumes default if only one loaded
    model_key: Optional[str] = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use (defaults to primary)."),
    output_format: str = Form("wav", description="Output format (wav, flac recommended).")
):
    """Enhances speech audio using a pre-loaded SpeechBrain model."""
    # Flow: check libs/model -> save upload -> load as tensor at
    # ENHANCEMENT_SR -> run inference in a worker thread -> save -> stream.
    # Raises HTTPException 501 when the AI libraries are missing and 503 when
    # the requested model is not loaded; helper HTTPExceptions pass through;
    # anything else becomes a 500.
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    # Use the provided key or the default
    actual_model_key = model_key or ENHANCEMENT_MODEL_KEY
    if actual_model_key not in enhancement_models:
        logger.error(f"Enhancement model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{actual_model_key}' is not loaded or available. Check server startup logs.")

    loaded_model = enhancement_models[actual_model_key]

    logger.info(f"Enhance request: file='{file.filename}', model='{actual_model_key}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    # Background tasks are attached to the response that gets returned, so
    # they only cover the success path; error paths clean up explicitly below.
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None
    try:
        # Load audio as tensor at the sample rate the enhancer is run with
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)

        # Inference is blocking, so it runs off the event loop.
        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Enhancement task completed.")

        # Save the result (tensor output from enhancer, written at ENHANCEMENT_SR)
        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)
        background_tasks.add_task(cleanup_file, output_path)

        # An upload without a filename falls back to a generic stem instead
        # of crashing os.path.splitext.
        output_filename = f"enhanced_{os.path.splitext(file.filename or 'audio')[0]}.{output_format}"
        media_type = f"audio/{output_format}"  # Basic media type
        return FileResponse(path=output_path, media_type=media_type, filename=output_filename)

    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during enhancement: {http_exc.detail}")
        cleanup_file(input_path)  # scheduled cleanup will not run on an error response
        if output_path:
            cleanup_file(output_path)
        raise http_exc
    except Exception as e:
        logger.error(f"Unexpected error during enhancement operation: {e}", exc_info=True)
        cleanup_file(input_path)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during enhancement: {str(e)}")
|
769 |
|
770 |
|
771 |
def _normalize_requested_stems(stems: List[str]) -> set:
    """Return the requested stem names as a lower-cased, de-duplicated set.

    Each submitted value is split on commas and stripped, because some form
    clients send a list field as a single comma-joined string
    (e.g. ``"vocals, drums"``) instead of repeating the field. Empty
    fragments are ignored.
    """
    return {
        part.strip().lower()
        for raw in stems
        for part in raw.split(",")
        if part.strip()
    }


def _safe_download_basename(filename: Optional[str], fallback: str = "audio") -> str:
    """Build a header-safe base name (no directory, no extension) from an upload name.

    The client-supplied filename is untrusted: it may be missing, contain path
    components, quotes, control characters or non-ASCII text, none of which are
    safe to interpolate into a quoted ``Content-Disposition`` value. Anything
    outside ``[A-Za-z0-9-_. ]`` is replaced with ``_``.
    """
    base = os.path.splitext(os.path.basename(filename or ""))[0]
    cleaned = "".join(
        ch if (ch.isascii() and (ch.isalnum() or ch in "-_. ")) else "_"
        for ch in base
    ).strip(" .")
    return cleaned or fallback


def _build_stems_zip_sync(
    stem_tensors: Dict[str, Any],
    stem_names: List[str],
    sample_rate: int,
    output_format: str,
) -> Optional[bytes]:
    """Encode the requested stems and pack them into an in-memory ZIP archive.

    Blocking (file I/O + DEFLATE compression); intended to be run in a worker
    thread. A stem that fails to save or zip is logged and skipped so the
    remaining stems are still delivered.

    Args:
        stem_tensors: Mapping of stem name to separated audio, as returned by
            ``_run_separation_sync`` (assumed dict-like; it is only indexed by
            stem name here).
        stem_names: Stem names to include, in archive order.
        sample_rate: Sampling rate passed to ``save_hf_audio``.
        output_format: Audio format/extension passed to ``save_hf_audio`` and
            used for the names inside the archive.

    Returns:
        The ZIP file content, or ``None`` if no stem could be added.
    """
    buffer = io.BytesIO()
    files_added = 0
    # The context manager finalizes the archive (central directory) on exit,
    # so the buffer is complete by the time getvalue() is called below.
    with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for stem_name in stem_names:
            if stem_name not in stem_tensors:
                # This case should be prevented by the earlier validation
                logger.warning(f"Requested stem '{stem_name}' not found in model output (validation error?).")
                continue
            stem_path = None
            try:
                stem_path = save_hf_audio(stem_tensors[stem_name], sample_rate, output_format)
                archive_name = f"{stem_name}.{output_format}"
                archive.write(stem_path, arcname=archive_name)
                files_added += 1
                logger.info(f"Added '{archive_name}' to ZIP.")
            except Exception as save_err:
                # Log error saving/zipping this stem but continue with others
                logger.error(f"Failed to save or add stem '{stem_name}' to zip: {save_err}", exc_info=True)
            finally:
                # The archive lives in memory, so the temporary stem file is no
                # longer needed once it has been copied in (or has failed).
                if stem_path:
                    cleanup_file(stem_path)
    return buffer.getvalue() if files_added else None


@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: Optional[str] = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use (defaults to primary)."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive.

    Args:
        background_tasks: Injected by FastAPI. Kept for interface compatibility;
            temporary files are removed synchronously so that cleanup also
            happens when the request fails.
        file: Uploaded audio file.
        model_key: Key into ``separation_models``; falls back to
            ``SEPARATION_MODEL_KEY`` when empty.
        stems: Stem names to extract. Values may be repeated form fields and/or
            comma-separated; matching is case-insensitive.
        output_format: Audio format used for each stem inside the archive.

    Returns:
        A ``StreamingResponse`` carrying a ZIP with one ``<stem>.<format>`` file
        per successfully generated stem.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 503 if the
            model is not loaded, 422 if no stems or unsupported stems are
            requested, 500 if separation fails or no stem could be generated.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")

    actual_model_key = model_key or SEPARATION_MODEL_KEY
    if actual_model_key not in separation_models:
        logger.error(f"Separation model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{actual_model_key}' is not loaded or available. Check server startup logs.")

    loaded_model = separation_models[actual_model_key]
    valid_stems = set(loaded_model.sources)
    requested_stems = _normalize_requested_stems(stems)

    if not requested_stems:
        raise HTTPException(status_code=422, detail="No stems requested for separation.")
    # Every requested stem must be one the model actually produces
    invalid_stems = requested_stems - valid_stems
    if invalid_stems:
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested: {', '.join(sorted(invalid_stems))}. Model '{actual_model_key}' provides: {', '.join(sorted(valid_stems))}")

    logger.info(f"Separate request: file='{file.filename}', model='{actual_model_key}', stems={requested_stems}, format='{output_format}'")
    input_path = await save_upload_file(file, prefix="separate_in_")

    try:
        # Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)

        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Separation task completed successfully.")

        # Encoding and compressing the stems is blocking work as well, so it is
        # kept off the event loop just like the inference step above.
        logger.info("Creating ZIP archive in memory...")
        zip_bytes = await asyncio.to_thread(
            _build_stems_zip_sync,
            all_separated_stems_tensors,
            sorted(requested_stems),  # deterministic archive order
            DEMUCS_SR,  # stems are written at the model's native sampling rate
            output_format,
        )
        if zip_bytes is None:
            logger.error("Failed to add any requested stems to the ZIP archive.")
            raise HTTPException(status_code=500, detail="Failed to generate any of the requested stems.")

        zip_filename = f"separated_{actual_model_key}_{_safe_download_basename(file.filename)}.zip"
        logger.info(f"Sending ZIP file: {zip_filename}")
        return StreamingResponse(
            iter([zip_bytes]),  # StreamingResponse needs an iterator
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during separation: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error during separation operation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during separation: {str(e)}") from e
    finally:
        # Background tasks only run after a response is successfully sent, so
        # relying on them would leak the upload whenever an error is raised.
        # The response body is already fully in memory, so the input file can
        # be removed here on both the success and the failure path.
        cleanup_file(input_path)
|
876 |
+
|
877 |
+
|
878 |
+
# --- How to Run ---
|
879 |
+
# 1. Ensure FFmpeg is installed and accessible in your PATH.
|
880 |
+
# 2. Save this code as `app.py`.
|
881 |
+
# 3. Create `requirements.txt` (including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs, python-multipart, protobuf).
|
882 |
+
# 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!).
|
883 |
+
# 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (7860 is the default port for HF Spaces; add `--reload` only for local development, never in production).
|
884 |
+
#
|
885 |
+
# --- WARNING ---
|
886 |
+
# - AI models require SIGNIFICANT RAM (often 8GB+) and CPU/GPU. Inference can be SLOW (minutes). Free HF Spaces might time out or lack resources.
|
887 |
+
# - First run downloads models (can take a long time/lots of disk space).
|
888 |
+
# - Ensure model names (e.g., "htdemucs") are correct.
|
889 |
+
# - MONITOR STARTUP LOGS carefully for model loading success/failure. Errors here will cause 503 errors later.
|