from __future__ import annotations  # Defer type-hint evaluation so torch-based hints don't fail when torch is absent

import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
import io
import zipfile

# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

# --- AI & Advanced Audio Imports ---
try:
    import torch
    # Transformers only needed if using HF pipelines directly, not for speechbrain/demucs manual loading
    # from transformers import pipeline
    import soundfile as sf
    import numpy as np
    import librosa

    # Specific Model Libraries
    import speechbrain.pretrained
    import demucs.pretrained  # provides get_model() for downloading/loading checkpoints
    import demucs.apply       # provides apply_model() for chunked inference

    print("AI and advanced audio libraries loaded.")
except ImportError as e:
    print(f"Error importing AI/Audio libraries: {e}")
    print("Ensure torch, soundfile, librosa, speechbrain, demucs are installed.")
    print("AI features will be unavailable.")
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None

# --- Configuration & Setup ---
TEMP_DIR = tempfile.gettempdir()
os.makedirs(TEMP_DIR, exist_ok=True)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Use consistent keys for storing/retrieving models
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
# Choose a default Demucs model (htdemucs is good quality)
SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one

enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

# Target sampling rates (confirm from model specifics if necessary)
ENHANCEMENT_SR = 8000  # The WHAMR! Sepformer enhancement model operates at 8 kHz
DEMUCS_SR = 44100      # Demucs default is 44.1kHz

# --- Device Selection ---
if torch:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {DEVICE}")
else:
    DEVICE = "cpu" # Fallback if torch failed import

# --- Helper Functions ---
def cleanup_file(file_path: str):
    try:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
            logger.info(f"Cleaned up temporary file: {file_path}")
    except Exception as e:
        logger.error(f"Error cleaning up file {file_path}: {e}", exc_info=False)

async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    _, file_extension = os.path.splitext(upload_file.filename)
    if not file_extension: file_extension = ".wav"
    temp_file_path = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{file_extension}")
    try:
        with open(temp_file_path, "wb") as buffer:
             while content := await upload_file.read(1024 * 1024): buffer.write(content)
        logger.info(f"Saved uploaded file '{upload_file.filename}' to temp path: {temp_file_path}")
        return temp_file_path
    except Exception as e:
        logger.error(f"Failed to save uploaded file {upload_file.filename}: {e}", exc_info=True)
        cleanup_file(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
         await upload_file.close()

# --- Audio Loading/Saving for AI Models ---
def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> tuple[torch.Tensor, int]:
    """Loads audio, converts to mono float32 Torch tensor, optionally resamples."""
    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded audio '{os.path.basename(file_path)}' with SR={orig_sr}, shape={audio.shape}, dtype={audio.dtype}")

        if audio.ndim > 1:
            # soundfile returns (frames, channels); average the channel axis to get mono
            logger.warning(f"Detected {audio.shape[1]} channels, converting to mono by averaging.")
            audio = np.mean(audio, axis=1)

        # Convert numpy array to torch tensor
        audio_tensor = torch.from_numpy(audio).float()

        # Resample if necessary using librosa then convert back to tensor
        if target_sr and orig_sr != target_sr:
            if librosa is None: raise RuntimeError("Librosa is required for resampling but not installed.")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz...")
            # Librosa works on numpy, so convert back temp.
            audio_np = audio_tensor.numpy()
            resampled_audio_np = librosa.resample(audio_np, orig_sr=orig_sr, target_sr=target_sr)
            audio_tensor = torch.from_numpy(resampled_audio_np).float()
            current_sr = target_sr
            logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")
        else:
            current_sr = orig_sr

        # Ensure tensor is on the correct device
        return audio_tensor.to(DEVICE), current_sr

    except Exception as e:
        logger.error(f"Error loading/processing audio file {file_path} for HF: {e}", exc_info=True)
        raise HTTPException(status_code=415, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Ensure it's a valid audio format.")

def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Saves audio data (Tensor or NumPy array) to a temporary file."""
    output_filename = f"ai_output_{uuid.uuid4().hex}.{output_format}"
    output_path = os.path.join(TEMP_DIR, output_filename)
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")

        # Convert tensor to numpy array if needed
        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to NumPy array.")
            # Ensure tensor is on CPU before converting to numpy
            audio_np = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            audio_np = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")

        # Ensure data is float32
        if audio_np.dtype != np.float32:
             logger.warning(f"Output audio dtype is {audio_np.dtype}, converting to float32 for saving.")
             audio_np = audio_np.astype(np.float32)

        # Clip values to avoid potential issues with formats expecting [-1, 1]
        audio_np = np.clip(audio_np, -1.0, 1.0)

        # Use soundfile (preferred for wav/flac)
        if output_format.lower() in ['wav', 'flac']:
             sf.write(output_path, audio_np, sampling_rate, format=output_format.upper())
        else:
             # For lossy formats, use pydub
             logger.debug(f"Using pydub to export to lossy format: {output_format}")
             # Scale float32 [-1, 1] to int16 for pydub
             audio_int16 = (audio_np * 32767).astype(np.int16)
             # Create AudioSegment (assuming mono for now)
             num_channels = 1 if audio_int16.ndim == 1 else audio_int16.shape[0] # Basic channel check
             if num_channels > 1: audio_int16 = audio_int16[0] # Use first channel if > 1; proper multi-channel handling TBD
             segment = AudioSegment(
                 audio_int16.tobytes(),
                 frame_rate=sampling_rate,
                 sample_width=audio_int16.dtype.itemsize,
                 channels=1 # Forcing mono currently
             )
             segment.export(output_path, format=output_format)

        return output_path
    except Exception as e:
        logger.error(f"Error saving AI processed audio to {output_path}: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail="Failed to save processed audio.")

# --- Synchronous AI Inference Functions ---

def _run_enhancement_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> torch.Tensor:
    """Synchronous wrapper for SpeechBrain enhancement model inference."""
    if not model: raise ValueError("Enhancement model not loaded")
    try:
        logger.info(f"Running speech enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")
        # SpeechBrain models usually take tensors directly
        # Add batch dimension if needed (most SB models expect batch)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        # Move tensor to the same device as the model (Pretrained objects expose .device)
        audio_tensor = audio_tensor.to(model.device)

        with torch.no_grad():
            # SepformerSeparation.separate_batch returns (batch, time, num_sources);
            # the enhancement checkpoint produces a single enhanced source
            est_sources = model.separate_batch(audio_tensor)

        # Keep the single source, drop the batch dimension, move back to CPU
        enhanced_audio = est_sources[:, :, 0].squeeze(0).detach().cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise

def _run_separation_sync(model: Any, audio_tensor: torch.Tensor, sampling_rate: int) -> Dict[str, torch.Tensor]:
    """Synchronous wrapper for Demucs source separation model inference."""
    if not model: raise ValueError("Separation model not loaded")
    if not demucs: raise RuntimeError("Demucs library not available")
    try:
        logger.info(f"Running source separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {audio_tensor.device})...")

        # Demucs expects audio as (batch, channels, samples)
        # Ensure input tensor is on the correct device
        model_device = next(model.parameters()).device
        audio_tensor = audio_tensor.to(model_device)

        # Add batch and channel dimensions if mono
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, N)
        elif audio_tensor.ndim == 2: # Should not happen often if load_audio ensures mono tensor
             logger.warning("Input tensor has 2 dims, assuming (batch, samples), adding channel dim.")
             audio_tensor = audio_tensor.unsqueeze(1) # (B, 1, N)

        # Ensure correct number of channels expected by the model (usually 2)
        if audio_tensor.shape[1] != model.audio_channels:
             logger.warning(f"Model expects {model.audio_channels} channels, input has {audio_tensor.shape[1]}. Repeating mono channel.")
             audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1) # Repeat mono to match expected channels


        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")

        with torch.no_grad():
            # demucs.apply.apply_model handles chunked, overlapped inference, which is
            # needed for long inputs and for "bag of models" checkpoints like htdemucs.
            # Input: (batch, channels, samples); output: (batch, stems, channels, samples).
            sources = demucs.apply.apply_model(
                model, audio_tensor, device=model_device,
                shifts=1, split=True, overlap=0.25
            )[0] # Remove batch dim -> (stems, channels, samples)

        logger.debug(f"Raw separated sources tensor shape: {sources.shape}") # (num_stems, channels, samples)

        # Map stems based on the model's sources list
        # Default for htdemucs: drums, bass, other, vocals
        stem_map = {name: sources[i] for i, name in enumerate(model.sources)}

        # Convert back to mono for simplicity (average channels) and move to CPU
        output_stems = {}
        for name, data in stem_map.items():
             output_stems[name] = data.mean(dim=0).detach().cpu() # Average channels, detach, move to CPU

        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems

    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise

# --- Model Loading Function ---
def load_hf_models():
    """Loads AI models at startup using correct libraries."""
    global enhancement_models, separation_models
    if torch is None or speechbrain is None or demucs is None:
        logger.error("Core AI libraries (torch, speechbrain, demucs) not available. Skipping model loading.")
        return

    # --- Load Enhancement Model (SpeechBrain) ---
    enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
    try:
        logger.info(f"Loading enhancement model: {enhancement_model_hparams} (using SpeechBrain)...")
        # Ensure SpeechBrain downloads to a writable location if needed (optional)
        # savedir = os.path.join(TEMP_DIR, "speechbrain_models")
        # os.makedirs(savedir, exist_ok=True)
        # Enhancement checkpoints such as sepformer-whamr-enhancement are loaded via
        # the SepformerSeparation interface (they yield a single enhanced source)
        enhancer = speechbrain.pretrained.SepformerSeparation.from_hparams(
            source=enhancement_model_hparams,
            # savedir=savedir, # Specify download dir if needed
            run_opts={"device": DEVICE} # Pass device option
        )
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer # Store with consistent key
        logger.info(f"Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {DEVICE}.")
    except Exception as e:
        logger.error(f"Failed to load enhancement model '{enhancement_model_hparams}': {e}", exc_info=True)

    # --- Load Separation Model (Demucs) ---
    # Using a standard pretrained model name from the demucs package
    separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs" or "mdx_extra_q"
    try:
        logger.info(f"Loading separation model: {separation_model_name} (using Demucs package)...")
        # demucs.pretrained.get_model downloads the checkpoint on first use
        separator = demucs.pretrained.get_model(separation_model_name)
        separator.to(DEVICE)
        separator.eval()
        separation_models[SEPARATION_MODEL_KEY] = separator # Store with consistent key
        logger.info(f"Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {DEVICE}.")
    except Exception as e:
         logger.error(f"Failed to load separation model '{separation_model_name}': {e}", exc_info=True)
         logger.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs).")


# --- FastAPI App ---
app = FastAPI(
    title="AI Audio Editor API",
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries (torch, speechbrain, demucs).",
    version="2.1.0", # Incremented version
)

@app.on_event("startup")
async def startup_event():
    logger.info("Application startup: Loading AI models...")
    # Run blocking model load in thread
    await asyncio.to_thread(load_hf_models)
    logger.info("Model loading process finished (check logs for success/failure).")

# --- API Endpoints ---

@app.get("/", tags=["General"])
def read_root():
    features = ["/trim", "/concat", "/volume", "/convert"]
    ai_features = []
    if enhancement_models: ai_features.append(f"/enhance (model: {ENHANCEMENT_MODEL_KEY})")
    if separation_models: ai_features.append(f"/separate (model: {SEPARATION_MODEL_KEY})")

    return {
        "message": "Welcome to the AI Audio Editor API.",
        "basic_features": features,
        "ai_features": ai_features if ai_features else "None available (check startup logs)",
        "notes": "Requires FFmpeg. AI features require specific models loaded at startup."
        }


# --- Basic Editing Endpoints ---
# The basic pydub-based endpoints (/trim, /concat, /volume, /convert) belong here
# and should use the cleanup_file and save_upload_file helpers above.
# A sketch of /trim follows; the others use the same pattern.
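# NOTE: illustrative sketch only, not part of the AI pipeline above. It assumes
# pydub's AudioSegment API (from_file, millisecond slicing, export); the endpoint
# parameters (start_ms, end_ms) are hypothetical names chosen for this sketch.
@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(0, description="Trim start in milliseconds."),
    end_ms: int = Form(..., description="Trim end in milliseconds."),
    output_format: str = Form("wav", description="Output format.")
):
    """Trims the uploaded audio to [start_ms, end_ms] using pydub."""
    if end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="end_ms must be greater than start_ms.")

    logger.info(f"Trim request: file='{file.filename}', range=[{start_ms}, {end_ms}] ms")
    input_path = await save_upload_file(file, prefix="trim_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = os.path.join(TEMP_DIR, f"trim_out_{uuid.uuid4().hex}.{output_format}")

    try:
        segment = AudioSegment.from_file(input_path)
        trimmed = segment[start_ms:end_ms] # pydub slicing is in milliseconds
        trimmed.export(output_path, format=output_format)
        background_tasks.add_task(cleanup_file, output_path)
        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"trimmed_{os.path.splitext(file.filename)[0]}.{output_format}"
        )
    except CouldntDecodeError:
        cleanup_file(output_path)
        raise HTTPException(status_code=415, detail="Could not decode the uploaded audio file.")
    except Exception as e:
        logger.error(f"Error during trim operation: {e}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during trim: {str(e)}")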

# --- AI Endpoints (Corrected) ---

@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    # Model ID is less relevant now if only one is loaded, but keep for future flexibility
    model_key: str = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use."),
    output_format: str = Form("wav", description="Output format (wav, flac recommended).")
):
    """Enhances speech audio using a pre-loaded SpeechBrain model."""
    if torch is None or speechbrain is None:
         raise HTTPException(status_code=501, detail="AI processing libraries (torch, speechbrain) not available.")
    if model_key not in enhancement_models:
         logger.error(f"Enhancement model key '{model_key}' requested but model not loaded.")
         raise HTTPException(status_code=503, detail=f"Enhancement model '{model_key}' is not loaded or available. Check server logs.")

    loaded_model = enhancement_models[model_key]

    logger.info(f"Enhance request: file='{file.filename}', model='{model_key}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    background_tasks.add_task(cleanup_file, input_path)
    output_path = None

    try:
        # Load audio as tensor, resampled to the model's native rate (ENHANCEMENT_SR)
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)

        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr # Pass SR even if unused by func now
        )
        logger.info("Enhancement task completed.")

        # Save the result (tensor output from enhancer)
        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format) # Save at model's SR
        background_tasks.add_task(cleanup_file, output_path)

        return FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
        )
    except Exception as e:
        logger.error(f"Error during enhancement operation: {e}", exc_info=True)
        if output_path: cleanup_file(output_path)
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during enhancement: {str(e)}")


@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: str = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive."""
    if torch is None or demucs is None:
         raise HTTPException(status_code=501, detail="AI processing libraries (torch, demucs) not available.")
    if model_key not in separation_models:
         logger.error(f"Separation model key '{model_key}' requested but model not loaded.")
         raise HTTPException(status_code=503, detail=f"Separation model '{model_key}' is not loaded or available. Check server logs.")

    loaded_model = separation_models[model_key]
    valid_stems = set(loaded_model.sources) # Get stems directly from loaded model
    requested_stems = set(s.lower() for s in stems)
    if not requested_stems.issubset(valid_stems):
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested. Model '{model_key}' provides: {', '.join(valid_stems)}")

    logger.info(f"Separate request: file='{file.filename}', model='{model_key}', stems={requested_stems}, format='{output_format}'")
    input_path = await save_upload_file(file, prefix="separate_in_")
    background_tasks.add_task(cleanup_file, input_path)
    stem_output_paths: Dict[str, str] = {}
    zip_buffer = None

    try:
        # Load audio as tensor, ensure correct SR (Demucs default 44.1kHz)
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)

        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr # Pass SR even if unused by func now
        )
        logger.info("Separation task completed.")

        # --- Create ZIP file in memory ---
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Save only the requested stems
            for stem_name in requested_stems:
                if stem_name in all_separated_stems_tensors:
                    stem_tensor = all_separated_stems_tensors[stem_name]
                    # Save stem temporarily (save_hf_audio handles tensor)
                    # Use the model's native sampling rate for output
                    stem_path = save_hf_audio(stem_tensor, DEMUCS_SR, output_format)
                    stem_output_paths[stem_name] = stem_path
                    background_tasks.add_task(cleanup_file, stem_path)

                    archive_name = f"{stem_name}_{os.path.splitext(file.filename)[0]}.{output_format}"
                    zipf.write(stem_path, arcname=archive_name)
                    logger.info(f"Added '{archive_name}' to ZIP.")
                else:
                    # This case should be prevented by the earlier validation
                    logger.warning(f"Requested stem '{stem_name}' not found in model output (should not happen).")

        zip_buffer.seek(0)

        zip_filename = f"separated_{model_key}_{os.path.splitext(file.filename)[0]}.zip"
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={'Content-Disposition': f'attachment; filename="{zip_filename}"'}
        )
    except Exception as e:
        logger.error(f"Error during separation operation: {e}", exc_info=True)
        for path in stem_output_paths.values(): cleanup_file(path)
        if zip_buffer: zip_buffer.close()
        if isinstance(e, HTTPException): raise e
        else: raise HTTPException(status_code=500, detail=f"An unexpected error occurred during separation: {str(e)}")
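
# Example client calls (illustrative; adjust host, port, and file names to your setup):
#
#   curl -X POST http://localhost:8000/enhance \
#        -F "file=@noisy_speech.wav" -F "output_format=wav" -o enhanced.wav
#
#   curl -X POST http://localhost:8000/separate \
#        -F "file=@song.mp3" -F "stems=vocals" -F "stems=drums" -o stems.zip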

# --- Remaining basic editing endpoints ---
# /concat, /volume, and /convert follow the same pattern as the /trim sketch above:
# save the upload, process with pydub, export, and return a FileResponse.

# --- How to Run ---
# 1. Ensure FFmpeg is installed.
# 2. Save code as `app.py`. Create/update `requirements.txt`.
# 3. Install: `pip install -r requirements.txt` (May take significant time/space!)
# 4. Run: `uvicorn app:app --reload --host 0.0.0.0`
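#
# A plausible `requirements.txt` for this app (illustrative, versions unpinned;
# python-multipart is required by FastAPI to parse UploadFile/Form fields):
#
#   fastapi
#   uvicorn[standard]
#   python-multipart
#   pydub
#   torch
#   soundfile
#   numpy
#   librosa
#   speechbrain
#   demucs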