ahk-d committed on
Commit
314aa29
·
verified ·
1 Parent(s): f6f9768

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +397 -0
app.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
10
+
# Configure logging: INFO and above go to the root handler, and this module
# logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
14
+
class StemSeparator:
    """Separate an audio file into stems by shelling out to Demucs or Spleeter.

    ``supported_models`` maps a model id to its configuration: ``command``
    (which CLI backend runs it), an optional ``model_name`` handed to that
    CLI, the number of ``stems`` produced, and a readable ``description``.
    """

    # File extensions treated as generated stems when scanning the output.
    _AUDIO_SUFFIXES = (".wav", ".mp3", ".flac", ".m4a")

    def __init__(self):
        self.supported_models = {
            "htdemucs": {
                "command": "demucs",
                "stems": 4,
                "description": "HTDemucs - High quality 4-stem separation",
            },
            "htdemucs_ft": {
                "command": "demucs",
                "model_name": "htdemucs_ft",
                "stems": 4,
                "description": "HTDemucs Fine-tuned - Enhanced 4-stem separation",
            },
            "htdemucs_6s": {
                "command": "demucs",
                "model_name": "htdemucs_6s",
                "stems": 6,
                "description": "HTDemucs 6-stem - Bass, Drums, Vocals, Other, Guitar, Piano",
            },
            "mdx": {
                "command": "demucs",
                "model_name": "mdx",
                "stems": 4,
                "description": "MDX - Optimized for vocal separation",
            },
            "mdx_extra": {
                "command": "demucs",
                "model_name": "mdx_extra",
                "stems": 4,
                "description": "MDX Extra - Enhanced vocal separation",
            },
            "spleeter_4stems": {
                "command": "spleeter",
                "model_name": "spleeter:4stems-waveform",
                "stems": 4,
                "description": "Spleeter 4-stem - Vocals, Bass, Drums, Other",
            },
            "spleeter_5stems": {
                "command": "spleeter",
                "model_name": "spleeter:5stems-waveform",
                "stems": 5,
                "description": "Spleeter 5-stem - Vocals, Bass, Drums, Piano, Other",
            },
        }

    @staticmethod
    def _python_executable() -> str:
        """Return the interpreter running this app, falling back to ``python``."""
        # sys.executable guarantees the probe and the real run use the same
        # environment (and therefore the same installed packages) as the app.
        return sys.executable or "python"

    @staticmethod
    def _tool_responds(command: List[str]) -> bool:
        """Return True when ``command`` exits with status 0 within 10 seconds."""
        try:
            result = subprocess.run(
                command, capture_output=True, text=True, timeout=10
            )
        except (subprocess.TimeoutExpired, OSError):
            # Missing binary, permission problem, or a hung tool: unavailable.
            return False
        return result.returncode == 0

    def check_dependencies(self) -> Dict[str, bool]:
        """Check if required tools are installed.

        Returns:
            A dict with the keys ``"demucs"`` and ``"spleeter"``; each value
            is True when the tool's ``--help`` invocation succeeds.
        """
        return {
            "demucs": self._tool_responds(
                [self._python_executable(), "-m", "demucs", "--help"]
            ),
            "spleeter": self._tool_responds(["spleeter", "--help"]),
        }

    def separate_audio(self, audio_file: str, model_choice: str) -> Tuple[List[str], str]:
        """Separate audio into stems using the selected model.

        Args:
            audio_file: Path of the audio file to process.
            model_choice: A key of ``self.supported_models``.

        Returns:
            ``(stem_paths, status_message)``. ``stem_paths`` is empty whenever
            the separation did not produce anything; failures are reported
            through the message rather than raised.
        """
        if not audio_file:
            return [], "❌ No audio file provided"

        if model_choice not in self.supported_models:
            return [], f"❌ Unsupported model: {model_choice}"

        model_config = self.supported_models[model_choice]

        builders = {
            "demucs": self._build_demucs_command,
            "spleeter": self._build_spleeter_command,
        }
        build_command = builders.get(model_config["command"])
        if build_command is None:
            return [], f"❌ Unknown command type: {model_config['command']}"

        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)

                # Outputs go to their own subdirectory so that the copied
                # input below is never picked up as a generated stem.
                output_path = temp_path / "output"
                output_path.mkdir()

                # Copy input file to temp directory with proper extension
                source = Path(audio_file)
                temp_input = temp_path / f"input{source.suffix}"
                shutil.copy2(source, temp_input)

                logger.info("Processing %s with %s", temp_input, model_choice)

                command = build_command(temp_input, output_path, model_config)

                # Execute separation
                logger.info("Running command: %s", " ".join(command))
                result = subprocess.run(
                    command,
                    capture_output=True,
                    text=True,
                    timeout=300,  # 5 minute timeout
                    cwd=temp_dir,
                )

                if result.returncode != 0:
                    error_msg = f"❌ Separation failed: {result.stderr}"
                    logger.error(error_msg)
                    return [], error_msg

                # Stems must be copied out before the temp directory vanishes.
                stems = self._collect_stems(output_path, model_choice)

                if not stems:
                    return [], "❌ No stems were generated"

                success_msg = f"✅ Successfully separated into {len(stems)} stems"
                logger.info(success_msg)
                return stems, success_msg

        except subprocess.TimeoutExpired:
            return [], "❌ Process timed out - file may be too large"
        except Exception as e:
            # UI boundary: report the failure as a status message.
            error_msg = f"❌ Error during separation: {str(e)}"
            logger.exception(error_msg)
            return [], error_msg

    def _build_demucs_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
        """Build the argument list that runs Demucs on ``input_file``.

        ``-n <model_name>`` is added only when the config names a model.
        """
        command = [self._python_executable(), "-m", "demucs"]

        if "model_name" in model_config:
            command.extend(["-n", model_config["model_name"]])

        command.extend([
            "-o", str(output_dir),
            "--filename", "{track}/{stem}.{ext}",  # Organized output structure
            str(input_file),
        ])

        return command

    def _build_spleeter_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
        """Build the argument list that runs Spleeter on ``input_file``.

        Falls back to ``spleeter:4stems-waveform`` when the config has no
        ``model_name``.
        """
        model_name = model_config.get("model_name", "spleeter:4stems-waveform")

        return [
            "spleeter", "separate",
            "-p", model_name,
            "-o", str(output_dir),
            "--filename_format", "{instrument}.{codec}",
            str(input_file),
        ]

    def _find_stem_files(self, output_dir: Path) -> List[Path]:
        """Return the non-empty audio files found anywhere under ``output_dir``."""
        return sorted(
            path
            for path in output_dir.rglob("*")
            if path.is_file()
            and path.suffix.lower() in self._AUDIO_SUFFIXES
            and path.stat().st_size > 0
        )

    def _collect_stems(self, output_dir: Path, model_choice: str) -> List[str]:
        """Copy every generated stem out of ``output_dir``.

        Args:
            output_dir: Directory tree the separation tool wrote into.
            model_choice: Unused; kept so the signature stays stable.

        Returns:
            Sorted paths of the copies; files that failed to copy are skipped.
        """
        copied = (
            self._copy_to_permanent_location(stem_file)
            for stem_file in self._find_stem_files(output_dir)
        )
        return sorted(path for path in copied if path)

    def _copy_to_permanent_location(self, temp_file: Path) -> Optional[str]:
        """Copy ``temp_file`` into ``./separated_stems`` under a unique name.

        Returns:
            The path of the copy as a string, or None when copying failed
            (the failure is logged).
        """
        try:
            # Create output directory if it doesn't exist
            output_dir = Path("./separated_stems")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Millisecond timestamp keeps names from different runs apart; the
            # counter covers two files landing in the same millisecond.
            timestamp = int(time.time() * 1000)
            permanent_file = output_dir / f"{temp_file.stem}_{timestamp}{temp_file.suffix}"
            attempt = 0
            while permanent_file.exists():
                attempt += 1
                permanent_file = output_dir / (
                    f"{temp_file.stem}_{timestamp}_{attempt}{temp_file.suffix}"
                )

            shutil.copy2(temp_file, permanent_file)
            return str(permanent_file)
        except OSError as e:
            logger.error(f"Failed to copy {temp_file}: {e}")
            return None
214
+
# Initialize separator: one shared instance used by the UI helpers below.
separator = StemSeparator()
217
+
def get_available_models() -> List[Tuple[str, str]]:
    """Get list of available models based on installed dependencies.

    Returns:
        ``(label, model_id)`` pairs for every supported model whose backend
        tool is installed. When none is, a single placeholder pair with the
        id ``"none"`` is returned so the dropdown is never empty.
    """
    deps = separator.check_dependencies()

    available_models = [
        (f"{model_id} ({config['stems']} stems) - {config['description']}", model_id)
        for model_id, config in separator.supported_models.items()
        if deps.get(config["command"], False)
    ]

    if not available_models:
        available_models = [("No models available - install demucs or spleeter", "none")]

    return available_models
232
+
def separate_stems_ui(audio_file: str, model_choice: str) -> Tuple[List[str], str]:
    """UI wrapper for stem separation.

    Short-circuits on the placeholder model id ``"none"`` and otherwise
    forwards to ``separator.separate_audio``.

    Returns:
        ``(stem_paths, status_message)``.
    """
    if model_choice == "none":
        return [], "❌ Please install demucs and/or spleeter first"

    return separator.separate_audio(audio_file, model_choice)
240
+
def create_audio_gallery(stems: List[str]) -> List[gr.Audio]:
    """Create one read-only audio component per stem path.

    Args:
        stems: Paths of the separated stem files.

    Returns:
        A ``gr.Audio`` per path, labelled ``Stem <n>: <file stem>``; an empty
        list when ``stems`` is empty.
    """
    if not stems:
        return []

    return [
        gr.Audio(
            value=stem_path,
            label=f"Stem {index}: {Path(stem_path).stem}",
            interactive=False,
            show_download_button=True,
        )
        for index, stem_path in enumerate(stems, start=1)
    ]
258
+
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the stem separator and return it.

    The layout has an upload/model/button column, a help column, a status
    box, a hidden results column with one audio player slot per possible
    stem, and a collapsible dependency-status section.
    """
    # Probe the backends once: each call spawns subprocesses.
    model_choices = get_available_models()
    default_model = model_choices[0][1] if model_choices else "none"

    # Enough player slots for the model that yields the most stems.
    max_stems = max(cfg["stems"] for cfg in separator.supported_models.values())

    with gr.Blocks(
        title="🎵 Advanced Music Stem Separator",
        theme=gr.themes.Soft(),
        css="""
        .audio-container { margin: 10px 0; }
        .status-success { color: #22c55e; font-weight: bold; }
        .status-error { color: #ef4444; font-weight: bold; }
        """
    ) as demo:

        gr.Markdown("""
        # 🎵 Advanced Music Stem Separator

        Separate music into individual stems (vocals, instruments, etc.) using state-of-the-art AI models.
        Supports up to 6 stems depending on the model chosen.

        **Supported Models:**
        - **Demucs Models**: HTDemucs, HTDemucs-FT, HTDemucs-6s, MDX, MDX-Extra
        - **Spleeter Models**: 4-stem and 5-stem separation

        **Requirements**: Install `demucs` and/or `spleeter` packages
        """)

        with gr.Row():
            with gr.Column(scale=2):
                audio_input = gr.Audio(
                    type="filepath",
                    label="🎼 Upload Audio File"
                )
                # NOTE(review): gr.Audio documents no `info` argument, so the
                # format hint is shown as a separate Markdown line.
                gr.Markdown("Supported formats: WAV, MP3, FLAC, M4A")

                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_model,
                    label="🧠 Separation Model",
                    info="Choose the AI model for stem separation"
                )

                separate_btn = gr.Button(
                    "🎛️ Separate Stems",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("""
                ### ℹ️ Model Info
                - **4-stem**: Vocals, Bass, Drums, Other
                - **5-stem**: + Piano
                - **6-stem**: + Guitar

                ### 💡 Tips
                - Higher quality input = better separation
                - Processing time varies by model and file length
                - Results will appear below after processing
                """)

        # Status display
        status_display = gr.Textbox(
            label="Status",
            interactive=False,
            visible=True
        )

        # Audio outputs: the players are created up front, inside the Blocks
        # context, and the click handler only updates them. Components
        # instantiated inside an event handler are not added to the page.
        stems_state = gr.State([])
        with gr.Column(visible=False) as audio_outputs:
            stems_heading = gr.Markdown()
            stem_players = [
                gr.Audio(
                    visible=False,
                    interactive=False,
                    show_download_button=True
                )
                for _ in range(max_stems)
            ]

        def process_and_display(audio_file, model_choice):
            """Run the separation and return one update per wired output."""
            hidden = [gr.update(value=None, visible=False) for _ in stem_players]

            if not audio_file:
                return [[], "❌ Please upload an audio file",
                        gr.update(visible=False), "", *hidden]

            # Process the audio
            stems, message = separate_stems_ui(audio_file, model_choice)

            if not stems:
                return [[], message, gr.update(visible=False), "", *hidden]

            # Only as many stems as there are slots can be played back; the
            # full list is still kept in the state output.
            shown = stems[:len(stem_players)]
            players = [
                gr.update(
                    value=stem_path,
                    label=f"🎵 {Path(stem_path).stem.replace('_', ' ').title()}",
                    visible=True
                )
                for stem_path in shown
            ]
            players.extend(hidden[len(shown):])

            heading = f"### 🎶 Separated Stems ({len(stems)} files)"
            return [stems, message, gr.update(visible=True), heading, *players]

        separate_btn.click(
            fn=process_and_display,
            inputs=[audio_input, model_dropdown],
            outputs=[stems_state, status_display, audio_outputs,
                     stems_heading, *stem_players],
            show_progress=True
        )

        # Dependency check display
        with gr.Accordion("🔧 System Status", open=False):
            def check_system():
                """Return a Markdown report of which backends are installed."""
                deps = separator.check_dependencies()
                status_text = "**Dependency Status:**\n"
                for tool, available in deps.items():
                    status = "✅ Available" if available else "❌ Not installed"
                    status_text += f"- {tool}: {status}\n"

                # Install instructions only when neither backend is usable.
                if not any(deps.values()):
                    status_text += "\n**Installation Instructions:**\n"
                    status_text += "```bash\n"
                    status_text += "# Install Demucs (recommended)\n"
                    status_text += "pip install demucs\n\n"
                    status_text += "# Install Spleeter (alternative)\n"
                    status_text += "pip install spleeter tensorflow\n"
                    status_text += "```"

                return status_text

            system_status = gr.Markdown(value=check_system())

            gr.Button("🔄 Refresh Status").click(
                fn=check_system,
                outputs=system_status
            )

    return demo
387
+
# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        show_error=True,
        debug=True
    )