File size: 45,473 Bytes
2f8d75b
526b24d
 
 
 
 
 
3e135af
526b24d
 
 
3f784c4
 
526b24d
 
 
 
 
 
3e135af
 
 
 
 
 
 
 
2c84da8
 
 
3e135af
2c84da8
526b24d
3e135af
526b24d
3e135af
526b24d
3e135af
526b24d
3e135af
3f784c4
3e135af
3f784c4
3e135af
3f784c4
 
3e135af
 
526b24d
3e135af
2c84da8
3e135af
2c84da8
526b24d
 
 
 
3f784c4
 
3ef3c9e
526b24d
 
2c84da8
 
 
 
 
 
 
 
 
526b24d
3e135af
2c84da8
 
526b24d
 
3f784c4
2c84da8
 
526b24d
3f784c4
 
3ef3c9e
2c84da8
 
 
3ef3c9e
3f784c4
2c84da8
3f784c4
2c84da8
3f784c4
2c84da8
 
3f784c4
2f8d75b
2c84da8
 
526b24d
2f8d75b
526b24d
2c84da8
526b24d
3e135af
526b24d
2c84da8
526b24d
 
 
2f8d75b
2c84da8
 
 
526b24d
2c84da8
3f784c4
526b24d
2c84da8
526b24d
2c84da8
526b24d
2c84da8
 
 
 
526b24d
 
2c84da8
 
526b24d
 
2c84da8
 
 
 
 
 
 
 
526b24d
3f784c4
2c84da8
 
 
 
 
 
526b24d
 
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526b24d
2c84da8
 
 
 
3e135af
2c84da8
 
 
526b24d
2c84da8
3e135af
2c84da8
3f784c4
2c84da8
3f784c4
526b24d
2c84da8
 
 
3f784c4
2c84da8
 
 
 
 
526b24d
2c84da8
2f8d75b
2c84da8
3e135af
526b24d
3f784c4
 
2c84da8
 
 
2f8d75b
526b24d
 
2c84da8
 
 
3f784c4
2c84da8
 
3f784c4
 
 
 
2c84da8
 
 
 
 
 
3ef3c9e
2c84da8
3f784c4
3ef3c9e
2c84da8
 
 
 
 
 
 
 
 
 
 
 
526b24d
3f784c4
526b24d
2c84da8
 
 
 
 
 
 
 
 
 
 
 
526b24d
2c84da8
 
526b24d
 
 
2c84da8
 
526b24d
2c84da8
 
2f8d75b
2c84da8
 
 
2f8d75b
2c84da8
 
 
 
 
 
 
 
2f8d75b
2c84da8
 
 
 
 
 
 
 
2f8d75b
 
2c84da8
2f8d75b
 
 
2c84da8
2f8d75b
 
2c84da8
 
 
 
 
2f8d75b
2c84da8
2f8d75b
3f784c4
2c84da8
 
526b24d
3e135af
2c84da8
3e135af
2c84da8
3e135af
2c84da8
3f784c4
2c84da8
 
 
 
 
 
 
 
 
2f8d75b
526b24d
 
2c84da8
 
 
526b24d
3f784c4
2c84da8
 
3e135af
3f784c4
3e135af
3f784c4
3e135af
2c84da8
 
 
 
 
 
3f784c4
2c84da8
 
 
 
 
 
 
 
526b24d
2c84da8
 
3e135af
2c84da8
 
 
 
 
 
 
2f8d75b
2c84da8
 
 
 
 
 
 
 
3f784c4
2c84da8
 
 
 
526b24d
 
3e135af
526b24d
3f784c4
3e135af
 
2c84da8
 
3e135af
3ef3c9e
3e135af
 
3f784c4
3ef3c9e
2c84da8
 
3e135af
3f784c4
3e135af
526b24d
3e135af
2c84da8
 
 
3f784c4
 
2c84da8
2f8d75b
3f784c4
3e135af
2f8d75b
3e135af
2c84da8
526b24d
2c84da8
3e135af
 
 
526b24d
3e135af
2f8d75b
3e135af
526b24d
3e135af
2c84da8
3f784c4
3e135af
2f8d75b
3e135af
2c84da8
 
526b24d
3e135af
 
2c84da8
3e135af
 
 
2c84da8
 
526b24d
 
3f784c4
526b24d
 
3e135af
2c84da8
526b24d
 
 
 
3e135af
 
 
2c84da8
3e135af
 
2c84da8
3e135af
2c84da8
 
526b24d
3ef3c9e
3f784c4
526b24d
 
2c84da8
526b24d
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526b24d
 
 
3e135af
2c84da8
 
 
526b24d
 
2f8d75b
3e135af
2c84da8
2f8d75b
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
2f8d75b
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
 
2f8d75b
2c84da8
 
2f8d75b
 
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
 
2f8d75b
2c84da8
 
2f8d75b
 
2c84da8
 
 
 
 
 
 
 
 
 
2f8d75b
 
2c84da8
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
2f8d75b
2c84da8
 
2f8d75b
 
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
2f8d75b
2c84da8
 
 
 
2f8d75b
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f8d75b
2c84da8
2f8d75b
2c84da8
2f8d75b
3ef3c9e
2c84da8
3ef3c9e
526b24d
 
 
 
2c84da8
 
3f784c4
526b24d
3f784c4
2c84da8
 
 
 
 
 
 
 
526b24d
2c84da8
3f784c4
 
3ef3c9e
526b24d
2c84da8
3f784c4
2c84da8
526b24d
3f784c4
2f8d75b
526b24d
 
2c84da8
 
2f8d75b
526b24d
2c84da8
2f8d75b
2c84da8
 
 
 
 
 
 
526b24d
2c84da8
3e135af
2c84da8
526b24d
 
 
 
 
 
2c84da8
526b24d
 
 
3f784c4
2c84da8
 
 
 
 
526b24d
2c84da8
3e135af
526b24d
 
2c84da8
 
 
 
 
 
 
 
 
3f784c4
 
2c84da8
 
3ef3c9e
526b24d
2c84da8
3f784c4
2c84da8
526b24d
3f784c4
2f8d75b
526b24d
2c84da8
526b24d
2c84da8
 
3e135af
2c84da8
3e135af
 
 
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e135af
 
2c84da8
526b24d
2c84da8
 
 
 
 
 
 
 
 
526b24d
2c84da8
526b24d
3f784c4
526b24d
2c84da8
 
3e135af
 
2c84da8
 
 
 
 
 
3ef3c9e
2c84da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
# app.py
import os
import uuid
import tempfile
import logging
import asyncio
from typing import List, Optional, Dict, Any
import traceback # For detailed error logging

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, Query
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
import io
import zipfile

# --- Basic Editing Imports ---
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError

# --- AI & Advanced Audio Imports ---
# Add extra logging around imports
# Console handler shared by the import/startup logger below (and reused later
# by the model loader), so import-time diagnostics are printed to the console.
ch = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

logger_init = logging.getLogger("AppInit")
logger_init.setLevel(logging.INFO)
# Only attach the handler once, so a re-imported module does not print
# every message twice.
if not logger_init.handlers:
    logger_init.addHandler(ch)

# Import the optional AI/audio stack. On any failure every name is bound to
# None and AI_LIBS_AVAILABLE stays False, so the rest of the module can check
# that flag instead of crashing at import time.
AI_LIBS_AVAILABLE = False
try:
    logger_init.info("Importing torch...")
    import torch
    logger_init.info("Importing soundfile...")
    import soundfile as sf
    logger_init.info("Importing numpy...")
    import numpy as np
    logger_init.info("Importing librosa...")
    import librosa
    logger_init.info("Importing speechbrain...")
    import speechbrain.pretrained
    logger_init.info("Importing demucs...")
    import demucs.separate
    import demucs.apply
    logger_init.info("AI and advanced audio libraries imported successfully.")
    AI_LIBS_AVAILABLE = True
except Exception as e:
    # Deliberately broader than ImportError: importing native-extension
    # packages such as torch can also fail with other errors (e.g. OSError
    # when a shared library cannot be loaded). Such a failure should only
    # disable the AI features, not stop the whole API from starting.
    logger_init.error(f"CRITICAL: Error importing AI/Audio libraries: {e}", exc_info=True)
    logger_init.error("Ensure torch, soundfile, librosa, speechbrain, demucs are in requirements.txt and installed correctly.")
    logger_init.error("AI features will be unavailable.")
    # Define placeholders so the rest of the code doesn't break completely on import error
    torch = None
    sf = None
    np = None
    librosa = None
    speechbrain = None
    demucs = None

# --- Configuration & Setup ---
# Directory for all per-request scratch files (uploads, exports, AI outputs).
TEMP_DIR = tempfile.gettempdir()
try:
    # gettempdir() picks a directory it can already create files in, so this
    # is normally a no-op; it only guards unusual container setups.
    os.makedirs(TEMP_DIR, exist_ok=True)
except OSError as mkdir_err:
    logger_init.error(f"Could not create temporary directory {TEMP_DIR}: {mkdir_err}")
    # Last-resort fallback: keep scratch files in the working directory.
    TEMP_DIR = "."
    logger_init.warning(f"Using current directory '{TEMP_DIR}' for temporary files.")


# Module logger used by the endpoint handlers. No handler is attached here;
# records propagate to whatever the root logger is configured with.
logger = logging.getLogger(__name__)

# --- Global Variables for Loaded Models ---
# Keys under which the loaded models are stored in the dicts below.
ENHANCEMENT_MODEL_KEY = "speechbrain_sepformer"
# Choose a default Demucs model (htdemucs is good quality)
SEPARATION_MODEL_KEY = "htdemucs" # Or use "mdx_extra_q" for a faster quantized one

# Populated by load_hf_models() at startup; they stay empty if loading fails.
enhancement_models: Dict[str, Any] = {}
separation_models: Dict[str, Any] = {}

# Sampling rates the AI models expect their input at.
# speechbrain/sepformer-whamr-enhancement is trained on WHAMR! at 8 kHz and
# its model card asks for 8 kHz single-channel input.
ENHANCEMENT_SR = 8000
DEMUCS_SR = 44100      # Demucs default is 44.1kHz

# --- Device Selection ---
# Device string the AI models are loaded onto ("cuda" when torch sees a GPU,
# otherwise "cpu").
if not (AI_LIBS_AVAILABLE and torch):
    DEVICE = "cpu" # Fallback if torch failed import
    logger_init.info("Torch not available or AI libs failed import, defaulting device to CPU.")
else:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    logger_init.info(f"Selected device for AI models: {DEVICE}")


# --- Helper Functions ---

def cleanup_file(file_path: str):
    """Best-effort deletion of a temporary file.

    Does nothing unless ``file_path`` is a non-empty ``str`` naming an existing
    path. Any error from the check or the removal is logged and swallowed, so
    this is safe to schedule as a background task or call from error handlers.
    """
    try:
        if not (file_path and isinstance(file_path, str)):
            return
        if not os.path.exists(file_path):
            return
        os.remove(file_path)
    except Exception as exc:
        # One failed cleanup must never interrupt the caller or other cleanups.
        logger.error(f"Error cleaning up file {file_path}: {exc}", exc_info=False)

async def save_upload_file(upload_file: UploadFile, prefix: str = "upload_") -> str:
    """Stream an uploaded file into TEMP_DIR and return the path written.

    The file is named ``<prefix><uuid4 hex><ext>``, where ``ext`` comes from
    the client-supplied filename and defaults to ``.wav`` when it has none.
    The upload is read in 1 MiB chunks and closed once the copy has been
    attempted.

    Raises:
        HTTPException(400): ``upload_file`` is missing or has no filename.
        HTTPException(500): the copy failed; any partial file is removed.
    """
    if not (upload_file and upload_file.filename):
        raise HTTPException(status_code=400, detail="Invalid file upload object.")

    suffix = os.path.splitext(upload_file.filename)[1] or ".wav"
    destination = os.path.join(TEMP_DIR, f"{prefix}{uuid.uuid4().hex}{suffix}")
    chunk_size = 1024 * 1024  # 1 MiB keeps memory bounded for large uploads

    try:
        logger.debug(f"Attempting to save uploaded file to: {destination}")
        with open(destination, "wb") as sink:
            chunk = await upload_file.read(chunk_size)
            while chunk:
                sink.write(chunk)
                chunk = await upload_file.read(chunk_size)
        logger.info(f"Saved uploaded file '{upload_file.filename}' ({upload_file.content_type}) to temp path: {destination}")
        return destination
    except Exception as exc:
        logger.error(f"Failed to save uploaded file '{upload_file.filename}' to {destination}: {exc}", exc_info=True)
        cleanup_file(destination)
        raise HTTPException(status_code=500, detail=f"Could not save uploaded file: {upload_file.filename}")
    finally:
        # Best-effort close; an error here must not mask the real outcome.
        try:
            await upload_file.close()
        except Exception:
            pass


# --- Audio Loading/Saving Functions ---

def load_audio_for_hf(file_path: str, target_sr: Optional[int] = None) -> "tuple[torch.Tensor, int]":
    """Load an audio file as a mono float32 tensor ready for the AI models.

    Decodes with soundfile, down-mixes multi-channel audio to mono by averaging
    the channels, optionally resamples with librosa, and moves the result to
    ``DEVICE``.

    The return annotation is a string on purpose: when the AI imports fail the
    module binds ``torch = None``, and an unquoted ``torch.Tensor`` would raise
    AttributeError while this function is being defined, breaking import of
    the whole app.

    Args:
        file_path: Local path of the audio file to load.
        target_sr: Desired sampling rate. ``None``/0, or a rate equal to the
            file's own, skips resampling.

    Returns:
        ``(audio_tensor, sampling_rate)``: a 1-D float32 tensor on ``DEVICE``
        and the rate it is sampled at.

    Raises:
        HTTPException(501): the AI/audio libraries failed to import.
        HTTPException(500): ``file_path`` is missing, or loading/resampling
            failed unexpectedly.
        HTTPException(415): soundfile could not decode the file.
        On decode/processing failures ``file_path`` is deleted before raising.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found at {file_path}")

    try:
        audio, orig_sr = sf.read(file_path, dtype='float32', always_2d=False)
        logger.info(f"Loaded '{os.path.basename(file_path)}' - SR={orig_sr}, Shape={audio.shape}, dtype={audio.dtype}")

        # soundfile returns mono files as 1-D arrays and everything else as
        # (frames, channels), so the channel axis is always 1. Averaging that
        # axis is correct even for very short clips, where guessing the
        # channel axis from the smaller dimension would pick the frame axis.
        if audio.ndim > 1:
            logger.info(f"Detected {audio.shape[1]} channels. Converting to mono by averaging axis 1.")
            audio = audio.mean(axis=1)
            logger.debug(f"Shape after mono conversion: {audio.shape}")

        audio = audio.reshape(-1)  # guarantee 1-D

        current_sr = orig_sr
        if target_sr and orig_sr != target_sr:
            if librosa is None: raise RuntimeError("Librosa missing for resampling")
            logger.info(f"Resampling from {orig_sr} Hz to {target_sr} Hz for {os.path.basename(file_path)}...")
            # NOTE(review): with librosa >= 0.10, res_type='kaiser_best' relies
            # on the optional `resampy` package — confirm it is installed.
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr, res_type='kaiser_best')
            current_sr = target_sr

        audio_tensor = torch.from_numpy(np.ascontiguousarray(audio, dtype=np.float32))
        if current_sr != orig_sr:
            logger.info(f"Resampled audio tensor shape: {audio_tensor.shape}")

        return audio_tensor.to(DEVICE), current_sr

    except sf.SoundFileError as sf_err:
        logger.error(f"SoundFileError loading {file_path}: {sf_err}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Could not decode audio file: {os.path.basename(file_path)}. Unsupported format or corrupt file. Error: {sf_err}") from sf_err
    except Exception as e:
        logger.error(f"Unexpected error loading/processing audio file {file_path} for AI: {e}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Could not load or process audio file: {os.path.basename(file_path)}. Check server logs.") from e


def save_hf_audio(audio_data: Any, sampling_rate: int, output_format: str = "wav") -> str:
    """Write AI-model output to a new file in TEMP_DIR and return its path.

    ``audio_data`` may be a torch.Tensor (detached and copied to CPU first) or
    a NumPy array. Samples are cast to float32, clipped to [-1, 1] and reduced
    to one dimension before writing. WAV/FLAC are written with soundfile; any
    other format is converted to 16-bit PCM and exported through pydub
    (which shells out to FFmpeg).

    Raises:
        HTTPException(501): the AI/audio libraries failed to import.
        HTTPException(500): conversion or writing failed (partial file removed).
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI Audio processing libraries not available.")

    fmt = output_format.lower()
    output_path = os.path.join(TEMP_DIR, f"ai_output_{uuid.uuid4().hex}.{fmt}")
    try:
        logger.info(f"Saving AI processed audio to {output_path} (SR={sampling_rate}, format={output_format})")

        # Normalise the input container to a NumPy array on the host.
        if isinstance(audio_data, torch.Tensor):
            logger.debug("Converting output tensor to NumPy array.")
            samples = audio_data.detach().cpu().numpy()
        elif isinstance(audio_data, np.ndarray):
            samples = audio_data
        else:
            raise TypeError(f"Unsupported audio data type for saving: {type(audio_data)}")

        if samples.dtype != np.float32:
            logger.warning(f"Output audio dtype is {samples.dtype}, converting to float32 for saving.")
            samples = samples.astype(np.float32)

        # Keep samples inside the [-1, 1] range both writers expect.
        samples = np.clip(samples, -1.0, 1.0)

        if samples.ndim > 1:
            logger.warning(f"Output audio data has {samples.ndim} dimensions, attempting to flatten or take first dimension.")
            # Treat the smallest axis as channels when it looks like a channel
            # count (2..9) and average it away; otherwise just flatten.
            smallest_axis = np.argmin(samples.shape)
            axis_len = samples.shape[smallest_axis]
            samples = np.mean(samples, axis=smallest_axis) if 1 < axis_len < 10 else samples.flatten()

        if fmt in ('wav', 'flac'):
            sf.write(output_path, samples, sampling_rate, format=output_format.upper())
        else:
            logger.debug(f"Using pydub to export to lossy format: {output_format}")
            pcm16 = (samples * 32767).astype(np.int16)
            # If FFmpeg is not on PATH, set AudioSegment.converter explicitly.
            AudioSegment(
                pcm16.tobytes(),
                frame_rate=sampling_rate,
                sample_width=pcm16.dtype.itemsize,
                channels=1,
            ).export(output_path, format=output_format)

        logger.info(f"Successfully saved AI audio to {output_path}")
        return output_path
    except Exception as exc:
        logger.error(f"Error saving AI processed audio to {output_path}: {exc}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"Failed to save processed audio to format '{output_format}'.")


# --- Pydub Loading/Exporting (for basic edits) ---
def load_audio_pydub(file_path: str) -> AudioSegment:
    """Decode ``file_path`` into a pydub AudioSegment for the basic-edit endpoints.

    The file's extension (lower-cased, without the dot) is passed to pydub as
    a format hint; with no extension pydub/FFmpeg detects the format itself.

    Raises:
        HTTPException(500): the path does not exist, or decoding failed for a
            reason other than an undecodable file.
        HTTPException(415): pydub raised CouldntDecodeError.
        On decode failures ``file_path`` is deleted before raising.
    """
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Internal error: Input audio file not found (pydub) at {file_path}")

    base_name = os.path.basename(file_path)
    try:
        logger.debug(f"Loading audio with pydub: {file_path}")
        extension = os.path.splitext(file_path)[1][1:].lower()
        format_hint = {"format": extension} if extension else {}
        segment = AudioSegment.from_file(file_path, **format_hint)
        logger.info(f"Loaded audio using pydub from: {file_path}")
        return segment
    except CouldntDecodeError as decode_err:
        logger.warning(f"Pydub CouldntDecodeError for {file_path}: {decode_err}")
        cleanup_file(file_path)
        raise HTTPException(status_code=415, detail=f"Unsupported audio format or corrupted file (pydub): {base_name}")
    except Exception as err:
        logger.error(f"Error loading audio file {file_path} with pydub: {err}", exc_info=True)
        cleanup_file(file_path)
        raise HTTPException(status_code=500, detail=f"Error processing audio file (pydub): {base_name}")

def export_audio_pydub(audio: AudioSegment, format: str) -> str:
    """Encode ``audio`` into a fresh TEMP_DIR file of ``format`` and return its path.

    The file is named ``edited_<uuid4 hex>.<format lower-cased>``.

    Raises:
        HTTPException(500): the pydub export failed (partial file removed).
    """
    fmt = format.lower()
    output_path = os.path.join(TEMP_DIR, f"edited_{uuid.uuid4().hex}.{fmt}")
    try:
        logger.info(f"Exporting audio using pydub to format '{format}' at {output_path}")
        audio.export(output_path, format=fmt)
    except Exception as export_err:
        logger.error(f"Error exporting audio with pydub to format {format}: {export_err}", exc_info=True)
        cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"Failed to export audio to format '{format}' using pydub.")
    return output_path


# --- Synchronous AI Inference Functions ---

def _run_enhancement_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "torch.Tensor":
    """Enhance one waveform with a loaded SpeechBrain model (blocking call).

    Intended to run off the event loop. ``sampling_rate`` is only logged;
    resampling to the model's rate is the caller's job.

    The annotations are strings so that defining this function does not
    evaluate ``torch.Tensor`` when the AI imports failed and ``torch`` is None.

    Args:
        model: SpeechBrain model exposing ``parameters()`` and either
            ``enhance_batch`` or ``forward``.
        audio_tensor: 1-D waveform, or a waveform that already has a leading
            batch dimension.
        sampling_rate: Rate of ``audio_tensor`` (informational only).

    Returns:
        The enhanced audio on CPU with the leading batch dimension removed,
        plus a trailing size-1 source dimension removed if the model adds one.

    Raises:
        ValueError: libraries/model unavailable, or the model has neither
            ``enhance_batch`` nor ``forward``.
        Any exception raised during inference is logged and re-raised.
    """
    import inspect

    if not AI_LIBS_AVAILABLE or not model: raise ValueError("Enhancement model/libs not available")
    try:
        logger.info(f"Running enhancement (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device  # run where the model lives
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)  # add batch dimension

        enhance_method = getattr(model, "enhance_batch", None) or getattr(model, "forward", None)
        if enhance_method is None:
            raise ValueError("Enhancement model has neither 'enhance_batch' nor 'forward'")

        # Look at the real parameter list: code.co_varnames also contains
        # local variable names and could report a false match.
        try:
            accepts_lengths = "lengths" in inspect.signature(enhance_method).parameters
        except (TypeError, ValueError):
            accepts_lengths = False

        with torch.no_grad():
            if accepts_lengths:
                # SpeechBrain expresses lengths relative to the longest item in
                # the batch (range [0, 1]); every item here spans the full tensor.
                rel_lengths = torch.ones(audio_tensor.shape[0], device=model_device)
                enhanced_tensor = enhance_method(audio_tensor, lengths=rel_lengths)
            else:
                enhanced_tensor = enhance_method(audio_tensor)

        enhanced_audio = enhanced_tensor.squeeze(0)
        # Separation-style models may return a trailing sources axis; with a
        # single source, drop it so callers get a plain waveform.
        if enhanced_audio.ndim > 1 and enhanced_audio.shape[-1] == 1:
            enhanced_audio = enhanced_audio.squeeze(-1)
        enhanced_audio = enhanced_audio.cpu()
        logger.info(f"Enhancement complete (output shape: {enhanced_audio.shape})")
        return enhanced_audio
    except Exception as e:
        logger.error(f"Error during synchronous enhancement inference: {e}", exc_info=True)
        raise # Re-raise to be caught by the async wrapper

def _run_separation_sync(model: Any, audio_tensor: "torch.Tensor", sampling_rate: int) -> "Dict[str, torch.Tensor]":
    """Split a waveform into stems with a Demucs model (blocking call).

    Intended to run off the event loop. ``sampling_rate`` is only logged; the
    caller must supply audio at the model's rate. The annotations are strings
    so the def does not evaluate ``torch.Tensor`` when ``torch`` is None.

    Args:
        model: Loaded Demucs model exposing ``parameters()``,
            ``audio_channels`` and ``sources``.
        audio_tensor: ``(samples,)``, ``(batch, samples)`` or
            ``(batch, channels, samples)``. Only the first batch item is
            separated; mono input is repeated to the model's channel count.
        sampling_rate: Rate of ``audio_tensor`` (informational only).

    Returns:
        ``{stem_name: waveform}`` with one entry per name in ``model.sources``;
        each waveform is a 1-D CPU tensor (channels averaged to mono).

    Raises:
        ValueError: libraries/model unavailable, or multi-channel input whose
            channel count differs from the model's.
        RuntimeError: the demucs package is unavailable.
        Any exception raised during inference is logged and re-raised.
    """
    if not AI_LIBS_AVAILABLE or not model: raise ValueError("Separation model/libs not available")
    if not demucs: raise RuntimeError("Demucs library missing")
    try:
        logger.info(f"Running separation (input shape: {audio_tensor.shape}, SR: {sampling_rate}, Device: {DEVICE})...")
        model_device = next(model.parameters()).device
        if audio_tensor.device != model_device:
            audio_tensor = audio_tensor.to(model_device)

        # Normalise to (batch, channels, samples).
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, N)
        elif audio_tensor.ndim == 2:
            audio_tensor = audio_tensor.unsqueeze(1)  # (B, 1, N)

        # Repeat channel if model expects stereo but input is mono
        if audio_tensor.shape[1] != model.audio_channels:
            if audio_tensor.shape[1] == 1:
                logger.info(f"Model expects {model.audio_channels} channels, input is mono. Repeating channel.")
                audio_tensor = audio_tensor.repeat(1, model.audio_channels, 1)
            else:
                raise ValueError(f"Input channels ({audio_tensor.shape[1]}) mismatch model ({model.audio_channels})")

        # Keep only the first item but retain the batch axis: Demucs v4's
        # apply_model (required for htdemucs) takes (batch, channels, length)
        # and returns (batch, sources, channels, length).
        audio_tensor = audio_tensor[:1]
        logger.debug(f"Input tensor shape for Demucs: {audio_tensor.shape}")

        with torch.no_grad():
            # apply_model handles chunking (split/overlap); progress bar disabled for logs.
            out = demucs.apply.apply_model(
                model,
                audio_tensor,
                device=model_device,
                shifts=1,
                split=True,
                overlap=0.25,
                progress=False,
            )

        logger.debug(f"Raw separated sources tensor shape: {out.shape}")

        stems = out[0]  # (sources, channels, samples) for the single batch item
        # Average channels to mono, detach and move to CPU, keyed by stem name.
        output_stems = {
            name: stems[idx].mean(dim=0).detach().cpu()
            for idx, name in enumerate(model.sources)
        }

        logger.info(f"Separation complete. Found stems: {list(output_stems.keys())}")
        return output_stems

    except Exception as e:
        logger.error(f"Error during synchronous separation inference: {e}", exc_info=True)
        raise


# --- Model Loading Function (Enhanced Logging) ---
def load_hf_models():
    """Load the enhancement and separation models into the module-level dicts.

    Meant to be called once at startup. Each model is loaded independently:
    a failure is logged with its traceback and leaves that feature
    unavailable, but is never raised. On success the models are stored in
    ``enhancement_models[ENHANCEMENT_MODEL_KEY]`` and
    ``separation_models[SEPARATION_MODEL_KEY]``.
    """
    logger_load = logging.getLogger("ModelLoader") # Use specific logger
    logger_load.setLevel(logging.INFO)
    # Reuse the module's console handler if this logger has none yet.
    if not logger_load.handlers and ch: logger_load.addHandler(ch)

    global enhancement_models, separation_models
    if not AI_LIBS_AVAILABLE:
        logger_load.error("Core AI libraries not available. Cannot load AI models.")
        return

    load_success_flags = {"enhancement": False, "separation": False}

    # --- Load Enhancement Model ---
    enhancement_model_hparams = "speechbrain/sepformer-whamr-enhancement"
    logger_load.info(f"--- Attempting to load Enhancement Model: {enhancement_model_hparams} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
        # The model card for this checkpoint loads it through SpeechBrain's
        # SepformerSeparation interface. Pass savedir=... if the default
        # download/cache location is not writable in the deployment.
        enhancer = speechbrain.pretrained.SepformerSeparation.from_hparams(
            source=enhancement_model_hparams,
            run_opts={"device": DEVICE},
        )
        model_device = next(enhancer.parameters()).device
        enhancement_models[ENHANCEMENT_MODEL_KEY] = enhancer
        logger_load.info(f"SUCCESS: Enhancement model '{ENHANCEMENT_MODEL_KEY}' loaded successfully on {model_device}.")
        load_success_flags["enhancement"] = True
    except Exception:
        logger_load.error(f"FAILED to load enhancement model '{enhancement_model_hparams}'. Error:", exc_info=False)
        logger_load.error(f"Traceback: {traceback.format_exc()}") # Log full traceback separately
        logger_load.warning("Enhancement features will be unavailable.")

    # --- Load Separation Model ---
    separation_model_name = SEPARATION_MODEL_KEY # e.g., "htdemucs"
    logger_load.info(f"--- Attempting to load Separation Model: {separation_model_name} ---")
    try:
        logger_load.info(f"Attempting load on device: {DEVICE}")
        # demucs.apply only provides apply_model; pretrained models are
        # fetched by name (downloaded on first use) via demucs.pretrained.
        from demucs.pretrained import get_model
        separator = get_model(name=separation_model_name)
        separator.to(DEVICE)
        separator.eval()
        model_device = next(separator.parameters()).device
        separation_models[SEPARATION_MODEL_KEY] = separator
        logger_load.info(f"SUCCESS: Separation model '{SEPARATION_MODEL_KEY}' loaded successfully on {model_device}.")
        logger_load.info(f"Separation model available sources: {separator.sources}")
        load_success_flags["separation"] = True
    except Exception:
        logger_load.error(f"FAILED to load separation model '{separation_model_name}'. Error:", exc_info=False)
        logger_load.error(f"Traceback: {traceback.format_exc()}")
        logger_load.warning("Ensure the 'demucs' package is installed correctly and the model name is valid (e.g., htdemucs). Check resource limits (RAM).")
        logger_load.warning("Separation features will be unavailable.")

    logger_load.info(f"--- Model loading attempts finished ---")
    logger_load.info(f"Enhancement Model Loaded: {load_success_flags['enhancement']}")
    logger_load.info(f"Separation Model Loaded: {load_success_flags['separation']}")


# --- FastAPI App ---
# The ASGI application object; the endpoint decorators below register on it.
app = FastAPI(
    version="2.1.2",  # Incremented version
    description="API for basic audio editing and AI-powered enhancement & separation. Requires FFmpeg and specific AI libraries.",
    title="AI Audio Editor API",
)

@app.on_event("startup")
async def startup_event():
    """Startup hook: log progress and load the AI models when their libraries imported.

    ``load_hf_models`` is blocking, so it runs in a worker thread via
    ``asyncio.to_thread``; this coroutine still awaits it and does not return
    until loading has finished (successfully or not).
    """
    logger_init.info("--- FastAPI Application Startup ---")
    if not AI_LIBS_AVAILABLE:
        logger_init.error("AI Libraries failed to import during init. AI features will be disabled.")
    else:
        logger_init.info("AI Libraries imported successfully. Loading models in background thread...")
        await asyncio.to_thread(load_hf_models)
        logger_init.info("Background model loading task finished (check ModelLoader logs above for details).")
    logger_init.info("--- Startup sequence complete ---")

# --- API Endpoints ---

@app.get("/", tags=["General"])
def read_root():
    """Return a welcome payload with the AI library and model-loading status.

    ``ai_models_status`` maps each model key to "Loaded" or a failure hint;
    when the AI libraries failed to import it holds a single "AI Status" entry.
    """
    if not AI_LIBS_AVAILABLE:
        ai_status = {"AI Status": "Libraries Failed Import"}
    else:
        ai_status = {
            ENHANCEMENT_MODEL_KEY: "Loaded" if enhancement_models else "Failed to load (check startup logs)",
        }
        if not separation_models:
            ai_status[SEPARATION_MODEL_KEY] = "Failed to load (check startup logs)"
        else:
            sep_model = separation_models.get(SEPARATION_MODEL_KEY)
            stem_names = ', '.join(sep_model.sources) if sep_model else 'N/A'
            ai_status[SEPARATION_MODEL_KEY] = f"Loaded (Sources: {stem_names})"

    return {
        "message": "Welcome to the AI Audio Editor API.",
        "status": "AI Libraries Available" if AI_LIBS_AVAILABLE else "AI Libraries Import Failed",
        "ai_models_status": ai_status,
        "basic_endpoints": ["/trim", "/concat", "/volume", "/convert"],
        "notes": "Requires FFmpeg. AI features require successful model loading at startup.",
    }


# --- Basic Editing Endpoints ---

@app.post("/trim", tags=["Basic Editing"])
async def trim_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to trim."),
    start_ms: int = Form(..., ge=0, description="Start time in milliseconds."),
    end_ms: int = Form(..., gt=0, description="End time in milliseconds.") # Ensure end > 0
):
    """Trims an audio file to the specified start and end times (in milliseconds). Uses Pydub.

    The result is exported in the upload's own format, taken from its file
    extension. It falls back to mp3 when the extension is missing or longer
    than five characters. An ``end_ms`` past the end of the audio is clamped
    by Pydub's slicing. A ``start_ms`` at or past the end is rejected, because
    it would produce an empty file.

    Raises:
        HTTPException: 422 for an invalid time range, any HTTPException raised
            by the load/export helpers, or 500 for unexpected errors.
    """
    if end_ms <= start_ms:
        raise HTTPException(status_code=422, detail="End time (end_ms) must be greater than start time (start_ms).")

    logger.info(f"Trim request: file='{file.filename}', start={start_ms}ms, end={end_ms}ms")
    input_path = await save_upload_file(file, prefix="trim_in_")
    output_path = None # Define before try block so the error handlers can see it

    try:
        audio = load_audio_pydub(input_path) # Can raise HTTPException
        if start_ms >= len(audio):
            raise HTTPException(
                status_code=422,
                detail=f"Start time ({start_ms}ms) is at or beyond the end of the audio ({len(audio)}ms).",
            )
        trimmed_audio = audio[start_ms:end_ms]
        logger.info(f"Audio trimmed to {len(trimmed_audio)}ms")

        # Determine original format for export
        original_format = os.path.splitext(file.filename)[1][1:].lower()
        # Use mp3 as default only if there is no extension or it is implausibly long
        if not original_format or len(original_format) > 5:
            original_format = "mp3"
            logger.warning(f"Using default export format 'mp3' for input '{file.filename}'")

        output_path = export_audio_pydub(trimmed_audio, original_format) # Can raise HTTPException

        # Create a more informative filename
        output_filename = f"trimmed_{start_ms}-{end_ms}_{os.path.splitext(file.filename)[0]}.{original_format}"

        response = FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}", # Best guess for media type
            filename=output_filename
        )
        # Remove the output only after the response has been sent.
        background_tasks.add_task(cleanup_file, output_path)
        return response
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during trim: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        # Catch other unexpected errors during trimming logic
        logger.error(f"Unexpected error during trim operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during trimming: {str(e)}") from e
    finally:
        # FastAPI attaches background tasks to the response the endpoint
        # returns, so a task queued before an exception never runs. Nothing
        # after load_audio_pydub() reads input_path, so delete it on every path.
        cleanup_file(input_path)


@app.post("/concat", tags=["Basic Editing"])
async def concatenate_audio(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(..., description="Two or more audio files to join in order."),
    output_format: str = Form("mp3", description="Desired output format (e.g., 'mp3', 'wav', 'ogg').")
):
    """Concatenates two or more audio files sequentially using Pydub.

    Uploads that are empty or fail to load are logged and skipped. The request
    only fails (400) when none of them could be loaded. The output format is
    stripped and lower-cased before export. The download name is derived from
    the first upload that was actually used, and counts only the uploads that
    were concatenated.

    Raises:
        HTTPException: 422 for fewer than two uploads, 400 when nothing could
            be loaded, any HTTPException raised by the save/export helpers, or
            500 for unexpected errors.
    """
    if len(files) < 2:
        raise HTTPException(status_code=422, detail="Please upload at least two files to concatenate.")

    output_format = output_format.strip().lower()
    logger.info(f"Concatenate request: {len(files)} files, output_format='{output_format}'")
    input_paths: List[str] = [] # Every saved upload; all are removed in `finally`
    output_path = None # Define before try block

    try:
        combined_audio: Optional[AudioSegment] = None
        used_filenames: List[str] = [] # Uploads actually concatenated, in order
        for i, file in enumerate(files):
            if not file or not file.filename:
                logger.warning(f"Skipping invalid file upload at index {i}.")
                continue # Skip potentially empty file entries

            input_path = await save_upload_file(file, prefix=f"concat_{i}_in_")
            input_paths.append(input_path)

            try:
                audio = load_audio_pydub(input_path)
                if combined_audio is None:
                    combined_audio = audio
                    logger.info(f"Starting concatenation with '{file.filename}' ({len(combined_audio)}ms)")
                else:
                    logger.info(f"Adding '{file.filename}' ({len(audio)}ms)")
                    combined_audio += audio
                used_filenames.append(file.filename)
            except HTTPException as load_exc:
                # Log error but continue trying to load other files if possible
                logger.error(f"Failed to load file '{file.filename}' for concatenation: {load_exc.detail}. Skipping this file.")
            except Exception as load_exc:
                logger.error(f"Unexpected error loading file '{file.filename}' for concatenation: {load_exc}. Skipping this file.", exc_info=True)

        if combined_audio is None:
            raise HTTPException(status_code=400, detail="No valid audio files could be loaded and combined.")

        logger.info(f"Final concatenated audio length: {len(combined_audio)}ms")
        if len(used_filenames) < len(files):
            logger.warning(f"Only {len(used_filenames)} of {len(files)} uploads were concatenated.")

        output_path = export_audio_pydub(combined_audio, output_format) # Can raise HTTPException

        # Name the download after the first upload that was actually used.
        first_filename_base = os.path.splitext(used_filenames[0])[0]
        output_filename = f"concat_{first_filename_base}_and_{len(used_filenames)-1}_others.{output_format}"

        response = FileResponse(
            path=output_path,
            media_type=f"audio/{output_format}",
            filename=output_filename
        )
        # Remove the output only after the response has been sent.
        background_tasks.add_task(cleanup_file, output_path)
        return response
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during concat: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        # Catch other unexpected errors during combining logic
        logger.error(f"Unexpected error during concat operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during concatenation: {str(e)}") from e
    finally:
        # Background tasks only run for a response the endpoint returns, so the
        # uploads are removed here on every path. They are not read after
        # load_audio_pydub() has returned.
        for path in input_paths:
            cleanup_file(path)


@app.post("/volume", tags=["Basic Editing"])
async def change_volume(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to adjust volume for."),
    change_db: float = Form(..., description="Volume change in decibels (dB). Positive increases, negative decreases.")
):
    """Adjusts the volume of an audio file by a specified decibel amount using Pydub.

    The result is exported in the upload's own format (from its extension),
    falling back to mp3 when the extension is missing or longer than five
    characters.

    Raises:
        HTTPException: 422 when ``change_db`` is not a finite number, any
            HTTPException raised by the load/export helpers, or 500 for
            unexpected errors.
    """
    import math

    # Float form fields can parse "inf"/"nan"; reject them up front instead of
    # letting the gain computation fail later with a generic 500.
    if not math.isfinite(change_db):
        raise HTTPException(status_code=422, detail="change_db must be a finite number.")

    logger.info(f"Volume request: file='{file.filename}', change_db={change_db}dB")
    input_path = await save_upload_file(file, prefix="volume_in_")
    output_path = None
    try:
        audio = load_audio_pydub(input_path)
        # Check for potential silence before applying gain
        if audio.dBFS == -float('inf'):
            logger.warning(f"Input file '{file.filename}' appears to be silent. Applying volume change may have no effect.")
        adjusted_audio = audio + change_db
        logger.info(f"Volume adjusted by {change_db}dB.")

        original_format = os.path.splitext(file.filename)[1][1:].lower()
        if not original_format or len(original_format) > 5:
            original_format = "mp3"

        output_path = export_audio_pydub(adjusted_audio, original_format)

        # Create filename
        sign = "+" if change_db >= 0 else ""
        output_filename = f"volume_{sign}{change_db}dB_{os.path.splitext(file.filename)[0]}.{original_format}"

        response = FileResponse(
            path=output_path,
            media_type=f"audio/{original_format}",
            filename=output_filename
        )
        # Remove the output only after the response has been sent.
        background_tasks.add_task(cleanup_file, output_path)
        return response
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during volume change: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during volume operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during volume adjustment: {str(e)}") from e
    finally:
        # Background tasks only run for a response the endpoint returns, so the
        # upload is removed here on every path. It is not read after loading.
        cleanup_file(input_path)


@app.post("/convert", tags=["Basic Editing"])
async def convert_format(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio file to convert."),
    output_format: str = Form(..., description="Target audio format (e.g., 'mp3', 'wav', 'ogg', 'flac').")
):
    """Converts an audio file to a different format using Pydub.

    ``output_format`` is stripped, lower-cased and checked against a fixed
    allow-list before any file is saved.

    Raises:
        HTTPException: 422 for a format outside the allow-list, any
            HTTPException raised by the load/export helpers, or 500 for
            unexpected errors.
    """
    # Define allowed formats explicitly.
    # NOTE(review): 'm4a', 'aac' and 'wma' do not look like ffmpeg muxer names
    # (ipod/adts/asf) -- confirm that export_audio_pydub maps them before relying on them.
    allowed_formats = {'mp3', 'wav', 'ogg', 'flac', 'aac', 'm4a', 'opus', 'wma', 'aiff'}
    output_format_lower = output_format.strip().lower()
    if output_format_lower not in allowed_formats:
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'. Allowed: {', '.join(sorted(list(allowed_formats)))}")

    logger.info(f"Convert request: file='{file.filename}', output_format='{output_format_lower}'")
    input_path = await save_upload_file(file, prefix="convert_in_")
    output_path = None
    try:
        # Load using pydub, which handles many input formats
        audio = load_audio_pydub(input_path)
        logger.info(f"Successfully loaded '{file.filename}' for conversion.")

        # Export using pydub
        output_path = export_audio_pydub(audio, output_format_lower)
        logger.info(f"Successfully exported to {output_format_lower}")

        # Construct new filename
        filename_base = os.path.splitext(file.filename)[0]
        output_filename = f"{filename_base}_converted.{output_format_lower}"

        # Determine media type (MIME type)
        media_type_map = {
            'mp3': 'audio/mpeg', 'wav': 'audio/wav', 'ogg': 'audio/ogg',
            'flac': 'audio/flac', 'aac': 'audio/aac', 'm4a': 'audio/mp4', # m4a uses the mp4 container
            'opus': 'audio/opus', 'wma': 'audio/x-ms-wma', 'aiff': 'audio/aiff'
        }
        media_type = media_type_map.get(output_format_lower, 'application/octet-stream') # Default binary if unknown

        response = FileResponse(
            path=output_path,
            media_type=media_type,
            filename=output_filename
        )
        # Remove the output only after the response has been sent.
        background_tasks.add_task(cleanup_file, output_path)
        return response
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during conversion: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during convert operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during format conversion: {str(e)}") from e
    finally:
        # Background tasks only run for a response the endpoint returns, so the
        # upload is removed here on every path. It is not read after loading.
        cleanup_file(input_path)


# --- AI Endpoints ---

@app.post("/enhance", tags=["AI Editing"])
async def enhance_speech(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Noisy speech audio file to enhance."),
    # Keep model_key optional for now, assumes default if only one loaded
    model_key: Optional[str] = Form(ENHANCEMENT_MODEL_KEY, description="Internal key of the enhancement model to use (defaults to primary)."),
    output_format: str = Form("wav", description="Output format (wav, flac recommended).")
):
    """Enhances speech audio using a pre-loaded SpeechBrain model.

    The upload is loaded at ``ENHANCEMENT_SR``. The model runs in a worker
    thread via ``asyncio.to_thread``, and the result is saved at
    ``ENHANCEMENT_SR`` in ``output_format``.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 503 if the
            requested model is not loaded, any HTTPException raised by the
            load/save helpers, or 500 for unexpected errors.
    """
    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    # Use the provided key or the default
    actual_model_key = model_key or ENHANCEMENT_MODEL_KEY
    if actual_model_key not in enhancement_models:
        logger.error(f"Enhancement model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Enhancement model '{actual_model_key}' is not loaded or available. Check server startup logs.")

    loaded_model = enhancement_models[actual_model_key]

    logger.info(f"Enhance request: file='{file.filename}', model='{actual_model_key}', format='{output_format}'")
    input_path = await save_upload_file(file, prefix="enhance_in_")
    output_path = None
    try:
        # Load audio as tensor, resampled to ENHANCEMENT_SR
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=ENHANCEMENT_SR)

        logger.info("Submitting enhancement task to background thread...")
        enhanced_audio_tensor = await asyncio.to_thread(
            _run_enhancement_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Enhancement task completed.")

        # Save the result at the enhancement sample rate
        output_path = save_hf_audio(enhanced_audio_tensor, ENHANCEMENT_SR, output_format)

        output_filename = f"enhanced_{os.path.splitext(file.filename)[0]}.{output_format}"
        media_type = f"audio/{output_format}" # Basic media type
        response = FileResponse(path=output_path, media_type=media_type, filename=output_filename)
        # Remove the output only after the response has been sent.
        background_tasks.add_task(cleanup_file, output_path)
        return response

    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during enhancement: {http_exc.detail}")
        if output_path:
            cleanup_file(output_path)
        raise
    except Exception as e:
        logger.error(f"Unexpected error during enhancement operation: {e}", exc_info=True)
        if output_path:
            cleanup_file(output_path)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during enhancement: {str(e)}") from e
    finally:
        # Background tasks only run for a response the endpoint returns, so the
        # upload is removed here on every path. It is not read after loading.
        cleanup_file(input_path)


@app.post("/separate", tags=["AI Editing"])
async def separate_sources(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Music audio file to separate into stems."),
    model_key: Optional[str] = Form(SEPARATION_MODEL_KEY, description="Internal key of the separation model to use (defaults to primary)."),
    stems: List[str] = Form(..., description="List of stems to extract (e.g., 'vocals', 'drums', 'bass', 'other')."),
    output_format: str = Form("wav", description="Output format for the stems (wav, flac recommended).")
):
    """Separates music into stems using a pre-loaded Demucs model. Returns a ZIP archive.

    ``stems`` may be sent as repeated form fields, as comma-separated values,
    or a mix of both. Names are case-insensitive and must all belong to the
    model's ``sources``. Each stem is written to a temp file, copied into an
    in-memory ZIP, and its temp file is removed straight away. Stems that fail
    to save are logged and skipped.

    ``background_tasks`` is unused now that nothing needs deferred cleanup; it
    is kept so the endpoint signature is unchanged.

    Raises:
        HTTPException: 501 if the AI libraries are unavailable, 503 if the
            model is not loaded, 422 for bad stems or output format, 500 if no
            stem could be produced or on unexpected errors.
    """
    from urllib.parse import quote

    if not AI_LIBS_AVAILABLE:
        raise HTTPException(status_code=501, detail="AI processing libraries not available.")
    actual_model_key = model_key or SEPARATION_MODEL_KEY
    if actual_model_key not in separation_models:
        logger.error(f"Separation model key '{actual_model_key}' requested but model not loaded.")
        raise HTTPException(status_code=503, detail=f"Separation model '{actual_model_key}' is not loaded or available. Check server startup logs.")

    # The format becomes part of every ZIP member name, so only allow plain
    # ASCII alphanumerics (rules out path separators and '..').
    fmt = output_format.strip().lower()
    if not (fmt.isascii() and fmt.isalnum()):
        raise HTTPException(status_code=422, detail=f"Invalid output format '{output_format}'.")

    loaded_model = separation_models[actual_model_key]
    valid_stems = set(loaded_model.sources)
    # Accept both `stems=vocals&stems=drums` and `stems=vocals,drums`.
    requested_stems = {
        part.strip().lower()
        for entry in stems
        for part in entry.split(",")
        if part.strip()
    }

    if not requested_stems:
        raise HTTPException(status_code=422, detail="No stems requested for separation.")
    # Every requested stem must be provided by this model
    invalid_stems = requested_stems - valid_stems
    if invalid_stems:
        raise HTTPException(status_code=422, detail=f"Invalid stem(s) requested: {', '.join(sorted(invalid_stems))}. Model '{actual_model_key}' provides: {', '.join(loaded_model.sources)}")

    logger.info(f"Separate request: file='{file.filename}', model='{actual_model_key}', stems={sorted(requested_stems)}, format='{fmt}'")
    input_path = await save_upload_file(file, prefix="separate_in_")

    try:
        # Load audio as tensor, resampled to DEMUCS_SR
        audio_tensor, current_sr = load_audio_for_hf(input_path, target_sr=DEMUCS_SR)

        logger.info("Submitting separation task to background thread...")
        all_separated_stems_tensors = await asyncio.to_thread(
            _run_separation_sync, loaded_model, audio_tensor, current_sr
        )
        logger.info("Separation task completed successfully.")

        # --- Create ZIP file in memory ---
        logger.info("Creating ZIP archive in memory...")
        zip_buffer = io.BytesIO()
        files_added_to_zip = 0
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for stem_name in sorted(requested_stems):
                if stem_name not in all_separated_stems_tensors:
                    # Should be prevented by the validation above
                    logger.warning(f"Requested stem '{stem_name}' not found in model output (validation error?).")
                    continue
                stem_path = None
                try:
                    # Save at the model's native rate (DEMUCS_SR)
                    stem_path = save_hf_audio(all_separated_stems_tensors[stem_name], DEMUCS_SR, fmt)
                    archive_name = f"{stem_name}.{fmt}"
                    zipf.write(stem_path, arcname=archive_name)
                    files_added_to_zip += 1
                    logger.info(f"Added '{archive_name}' to ZIP.")
                except Exception as save_err:
                    # Log error saving/zipping this stem but continue with others
                    logger.error(f"Failed to save or add stem '{stem_name}' to zip: {save_err}", exc_info=True)
                finally:
                    # The stem's bytes are now in the in-memory ZIP (or the
                    # write failed), so its temp file is no longer needed.
                    if stem_path:
                        cleanup_file(stem_path)

        if files_added_to_zip == 0:
            logger.error("Failed to add any requested stems to the ZIP archive.")
            raise HTTPException(status_code=500, detail="Failed to generate any of the requested stems.")

        zip_filename = f"separated_{actual_model_key}_{os.path.splitext(file.filename)[0]}.zip"
        # Header values must be latin-1 encodable. Use the RFC 6266 `filename*`
        # form when the name needs escaping (non-ASCII, quotes, ...).
        quoted_name = quote(zip_filename)
        if quoted_name != zip_filename:
            content_disposition = f"attachment; filename*=utf-8''{quoted_name}"
        else:
            content_disposition = f'attachment; filename="{zip_filename}"'

        logger.info(f"Sending ZIP file: {zip_filename}")
        return StreamingResponse(
            iter([zip_buffer.getvalue()]), # StreamingResponse needs an iterator
            media_type="application/zip",
            headers={'Content-Disposition': content_disposition}
        )
    except HTTPException as http_exc:
        logger.error(f"HTTP Exception during separation: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error during separation operation: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected server error occurred during separation: {str(e)}") from e
    finally:
        # Background tasks only run for a response the endpoint returns, so the
        # upload is removed here on every path. It is not read after loading.
        cleanup_file(input_path)


# --- How to Run ---
# 1. Ensure FFmpeg is installed and accessible in your PATH.
# 2. Save this code as `app.py`.
# 3. Create `requirements.txt` (including fastapi, uvicorn, pydub, torch, soundfile, librosa, speechbrain, demucs, python-multipart, protobuf).
# 4. Install dependencies: `pip install -r requirements.txt` (This can take significant time and disk space!).
# 5. Run the FastAPI server: `uvicorn app:app --host 0.0.0.0 --port 7860` (7860 is the Hugging Face Spaces default port; add --reload only for local development).
#
# --- WARNING ---
# - AI models require SIGNIFICANT RAM (often 8GB+) and CPU/GPU. Inference can be SLOW (minutes). Free HF Spaces might time out or lack resources.
# - First run downloads models (can take a long time/lots of disk space).
# - Ensure model names (e.g., "htdemucs") are correct.
# - MONITOR STARTUP LOGS carefully for model loading success/failure. Errors here will cause 503 errors later.