Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import tempfile
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import shutil
|
6 |
+
from pathlib import Path
|
7 |
+
import logging
|
8 |
+
from typing import List, Tuple, Dict
|
9 |
+
import json
|
10 |
+
|
11 |
+
# Configure logging
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
class StemSeparator:
|
16 |
+
"""Modern stem separation with multiple model support"""
|
17 |
+
|
18 |
+
def __init__(self):
|
19 |
+
self.supported_models = {
|
20 |
+
"htdemucs": {
|
21 |
+
"command": "demucs",
|
22 |
+
"stems": 4,
|
23 |
+
"description": "HTDemucs - High quality 4-stem separation"
|
24 |
+
},
|
25 |
+
"htdemucs_ft": {
|
26 |
+
"command": "demucs",
|
27 |
+
"model_name": "htdemucs_ft",
|
28 |
+
"stems": 4,
|
29 |
+
"description": "HTDemucs Fine-tuned - Enhanced 4-stem separation"
|
30 |
+
},
|
31 |
+
"htdemucs_6s": {
|
32 |
+
"command": "demucs",
|
33 |
+
"model_name": "htdemucs_6s",
|
34 |
+
"stems": 6,
|
35 |
+
"description": "HTDemucs 6-stem - Bass, Drums, Vocals, Other, Guitar, Piano"
|
36 |
+
},
|
37 |
+
"mdx": {
|
38 |
+
"command": "demucs",
|
39 |
+
"model_name": "mdx",
|
40 |
+
"stems": 4,
|
41 |
+
"description": "MDX - Optimized for vocal separation"
|
42 |
+
},
|
43 |
+
"mdx_extra": {
|
44 |
+
"command": "demucs",
|
45 |
+
"model_name": "mdx_extra",
|
46 |
+
"stems": 4,
|
47 |
+
"description": "MDX Extra - Enhanced vocal separation"
|
48 |
+
},
|
49 |
+
"spleeter_4stems": {
|
50 |
+
"command": "spleeter",
|
51 |
+
"model_name": "spleeter:4stems-waveform",
|
52 |
+
"stems": 4,
|
53 |
+
"description": "Spleeter 4-stem - Vocals, Bass, Drums, Other"
|
54 |
+
},
|
55 |
+
"spleeter_5stems": {
|
56 |
+
"command": "spleeter",
|
57 |
+
"model_name": "spleeter:5stems-waveform",
|
58 |
+
"stems": 5,
|
59 |
+
"description": "Spleeter 5-stem - Vocals, Bass, Drums, Piano, Other"
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
63 |
+
def check_dependencies(self) -> Dict[str, bool]:
|
64 |
+
"""Check if required tools are installed"""
|
65 |
+
dependencies = {}
|
66 |
+
|
67 |
+
# Check demucs
|
68 |
+
try:
|
69 |
+
result = subprocess.run(["python", "-m", "demucs", "--help"],
|
70 |
+
capture_output=True, text=True, timeout=10)
|
71 |
+
dependencies["demucs"] = result.returncode == 0
|
72 |
+
except (subprocess.TimeoutExpired, FileNotFoundError):
|
73 |
+
dependencies["demucs"] = False
|
74 |
+
|
75 |
+
# Check spleeter
|
76 |
+
try:
|
77 |
+
result = subprocess.run(["spleeter", "--help"],
|
78 |
+
capture_output=True, text=True, timeout=10)
|
79 |
+
dependencies["spleeter"] = result.returncode == 0
|
80 |
+
except (subprocess.TimeoutExpired, FileNotFoundError):
|
81 |
+
dependencies["spleeter"] = False
|
82 |
+
|
83 |
+
return dependencies
|
84 |
+
|
85 |
+
def separate_audio(self, audio_file: str, model_choice: str) -> Tuple[List[str], str]:
|
86 |
+
"""Separate audio into stems using the selected model"""
|
87 |
+
if not audio_file:
|
88 |
+
return [], "β No audio file provided"
|
89 |
+
|
90 |
+
if model_choice not in self.supported_models:
|
91 |
+
return [], f"β Unsupported model: {model_choice}"
|
92 |
+
|
93 |
+
model_config = self.supported_models[model_choice]
|
94 |
+
|
95 |
+
try:
|
96 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
97 |
+
temp_path = Path(temp_dir)
|
98 |
+
|
99 |
+
# Copy input file to temp directory with proper extension
|
100 |
+
input_file = Path(audio_file)
|
101 |
+
temp_input = temp_path / f"input{input_file.suffix}"
|
102 |
+
shutil.copy2(audio_file, temp_input)
|
103 |
+
|
104 |
+
logger.info(f"Processing {temp_input} with {model_choice}")
|
105 |
+
|
106 |
+
# Build command based on model type
|
107 |
+
if model_config["command"] == "demucs":
|
108 |
+
command = self._build_demucs_command(temp_input, temp_path, model_config)
|
109 |
+
elif model_config["command"] == "spleeter":
|
110 |
+
command = self._build_spleeter_command(temp_input, temp_path, model_config)
|
111 |
+
else:
|
112 |
+
return [], f"β Unknown command type: {model_config['command']}"
|
113 |
+
|
114 |
+
# Execute separation
|
115 |
+
logger.info(f"Running command: {' '.join(command)}")
|
116 |
+
result = subprocess.run(
|
117 |
+
command,
|
118 |
+
capture_output=True,
|
119 |
+
text=True,
|
120 |
+
timeout=300, # 5 minute timeout
|
121 |
+
cwd=temp_dir
|
122 |
+
)
|
123 |
+
|
124 |
+
if result.returncode != 0:
|
125 |
+
error_msg = f"β Separation failed: {result.stderr}"
|
126 |
+
logger.error(error_msg)
|
127 |
+
return [], error_msg
|
128 |
+
|
129 |
+
# Collect output stems
|
130 |
+
stems = self._collect_stems(temp_path, model_choice)
|
131 |
+
|
132 |
+
if not stems:
|
133 |
+
return [], "β No stems were generated"
|
134 |
+
|
135 |
+
success_msg = f"β
Successfully separated into {len(stems)} stems"
|
136 |
+
logger.info(success_msg)
|
137 |
+
return stems, success_msg
|
138 |
+
|
139 |
+
except subprocess.TimeoutExpired:
|
140 |
+
return [], "β Process timed out - file may be too large"
|
141 |
+
except Exception as e:
|
142 |
+
error_msg = f"β Error during separation: {str(e)}"
|
143 |
+
logger.error(error_msg)
|
144 |
+
return [], error_msg
|
145 |
+
|
146 |
+
def _build_demucs_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
|
147 |
+
"""Build demucs command"""
|
148 |
+
command = ["python", "-m", "demucs"]
|
149 |
+
|
150 |
+
if "model_name" in model_config:
|
151 |
+
command.extend(["-n", model_config["model_name"]])
|
152 |
+
|
153 |
+
command.extend([
|
154 |
+
"-o", str(output_dir),
|
155 |
+
"--filename", "{track}/{stem}.{ext}", # Organized output structure
|
156 |
+
str(input_file)
|
157 |
+
])
|
158 |
+
|
159 |
+
return command
|
160 |
+
|
161 |
+
def _build_spleeter_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
|
162 |
+
"""Build spleeter command"""
|
163 |
+
model_name = model_config.get("model_name", "spleeter:4stems-waveform")
|
164 |
+
|
165 |
+
command = [
|
166 |
+
"spleeter", "separate",
|
167 |
+
"-p", model_name,
|
168 |
+
"-o", str(output_dir),
|
169 |
+
"--filename_format", "{instrument}.{codec}",
|
170 |
+
str(input_file)
|
171 |
+
]
|
172 |
+
|
173 |
+
return command
|
174 |
+
|
175 |
+
def _collect_stems(self, output_dir: Path, model_choice: str) -> List[str]:
|
176 |
+
"""Collect generated stem files"""
|
177 |
+
stems = []
|
178 |
+
|
179 |
+
# Search for audio files in output directory
|
180 |
+
for audio_file in output_dir.rglob("*.wav"):
|
181 |
+
if audio_file.is_file() and audio_file.stat().st_size > 0:
|
182 |
+
# Copy to a permanent location that Gradio can access
|
183 |
+
permanent_path = self._copy_to_permanent_location(audio_file)
|
184 |
+
if permanent_path:
|
185 |
+
stems.append(permanent_path)
|
186 |
+
|
187 |
+
# Also check for other common audio formats
|
188 |
+
for ext in ["*.mp3", "*.flac", "*.m4a"]:
|
189 |
+
for audio_file in output_dir.rglob(ext):
|
190 |
+
if audio_file.is_file() and audio_file.stat().st_size > 0:
|
191 |
+
permanent_path = self._copy_to_permanent_location(audio_file)
|
192 |
+
if permanent_path:
|
193 |
+
stems.append(permanent_path)
|
194 |
+
|
195 |
+
return sorted(stems)
|
196 |
+
|
197 |
+
def _copy_to_permanent_location(self, temp_file: Path) -> str:
|
198 |
+
"""Copy temporary file to permanent location for Gradio"""
|
199 |
+
try:
|
200 |
+
# Create output directory if it doesn't exist
|
201 |
+
output_dir = Path("./separated_stems")
|
202 |
+
output_dir.mkdir(exist_ok=True)
|
203 |
+
|
204 |
+
# Generate unique filename
|
205 |
+
import time
|
206 |
+
timestamp = int(time.time() * 1000)
|
207 |
+
permanent_file = output_dir / f"{temp_file.stem}_{timestamp}{temp_file.suffix}"
|
208 |
+
|
209 |
+
shutil.copy2(temp_file, permanent_file)
|
210 |
+
return str(permanent_file)
|
211 |
+
except Exception as e:
|
212 |
+
logger.error(f"Failed to copy {temp_file}: {e}")
|
213 |
+
return None
|
214 |
+
|
215 |
+
# Initialize separator
|
216 |
+
separator = StemSeparator()
|
217 |
+
|
218 |
+
def get_available_models() -> List[Tuple[str, str]]:
|
219 |
+
"""Get list of available models based on installed dependencies"""
|
220 |
+
deps = separator.check_dependencies()
|
221 |
+
available_models = []
|
222 |
+
|
223 |
+
for model_id, config in separator.supported_models.items():
|
224 |
+
if config["command"] in deps and deps[config["command"]]:
|
225 |
+
label = f"{model_id} ({config['stems']} stems) - {config['description']}"
|
226 |
+
available_models.append((label, model_id))
|
227 |
+
|
228 |
+
if not available_models:
|
229 |
+
available_models = [("No models available - install demucs or spleeter", "none")]
|
230 |
+
|
231 |
+
return available_models
|
232 |
+
|
233 |
+
def separate_stems_ui(audio_file: str, model_choice: str) -> Tuple[List[str], str]:
|
234 |
+
"""UI wrapper for stem separation"""
|
235 |
+
if model_choice == "none":
|
236 |
+
return [], "β Please install demucs and/or spleeter first"
|
237 |
+
|
238 |
+
stems, message = separator.separate_audio(audio_file, model_choice)
|
239 |
+
return stems, message
|
240 |
+
|
241 |
+
def create_audio_gallery(stems: List[str]) -> List[gr.Audio]:
|
242 |
+
"""Create audio components for each stem"""
|
243 |
+
if not stems:
|
244 |
+
return []
|
245 |
+
|
246 |
+
audio_components = []
|
247 |
+
for i, stem_path in enumerate(stems):
|
248 |
+
stem_name = Path(stem_path).stem
|
249 |
+
audio_comp = gr.Audio(
|
250 |
+
value=stem_path,
|
251 |
+
label=f"Stem {i+1}: {stem_name}",
|
252 |
+
interactive=False,
|
253 |
+
show_download_button=True
|
254 |
+
)
|
255 |
+
audio_components.append(audio_comp)
|
256 |
+
|
257 |
+
return audio_components
|
258 |
+
|
259 |
+
# Create Gradio interface
|
260 |
+
def create_interface():
|
261 |
+
with gr.Blocks(
|
262 |
+
title="π΅ Advanced Music Stem Separator",
|
263 |
+
theme=gr.themes.Soft(),
|
264 |
+
css="""
|
265 |
+
.audio-container { margin: 10px 0; }
|
266 |
+
.status-success { color: #22c55e; font-weight: bold; }
|
267 |
+
.status-error { color: #ef4444; font-weight: bold; }
|
268 |
+
"""
|
269 |
+
) as demo:
|
270 |
+
|
271 |
+
gr.Markdown("""
|
272 |
+
# π΅ Advanced Music Stem Separator
|
273 |
+
|
274 |
+
Separate music into individual stems (vocals, instruments, etc.) using state-of-the-art AI models.
|
275 |
+
Supports up to 6 stems depending on the model chosen.
|
276 |
+
|
277 |
+
**Supported Models:**
|
278 |
+
- **Demucs Models**: HTDemucs, HTDemucs-FT, HTDemucs-6s, MDX, MDX-Extra
|
279 |
+
- **Spleeter Models**: 4-stem and 5-stem separation
|
280 |
+
|
281 |
+
**Requirements**: Install `demucs` and/or `spleeter` packages
|
282 |
+
""")
|
283 |
+
|
284 |
+
with gr.Row():
|
285 |
+
with gr.Column(scale=2):
|
286 |
+
audio_input = gr.Audio(
|
287 |
+
type="filepath",
|
288 |
+
label="πΌ Upload Audio File",
|
289 |
+
info="Supported formats: WAV, MP3, FLAC, M4A"
|
290 |
+
)
|
291 |
+
|
292 |
+
model_dropdown = gr.Dropdown(
|
293 |
+
choices=get_available_models(),
|
294 |
+
value=get_available_models()[0][1] if get_available_models() else "none",
|
295 |
+
label="π§ Separation Model",
|
296 |
+
info="Choose the AI model for stem separation"
|
297 |
+
)
|
298 |
+
|
299 |
+
separate_btn = gr.Button(
|
300 |
+
"ποΈ Separate Stems",
|
301 |
+
variant="primary",
|
302 |
+
size="lg"
|
303 |
+
)
|
304 |
+
|
305 |
+
with gr.Column(scale=1):
|
306 |
+
gr.Markdown("""
|
307 |
+
### βΉοΈ Model Info
|
308 |
+
- **4-stem**: Vocals, Bass, Drums, Other
|
309 |
+
- **5-stem**: + Piano
|
310 |
+
- **6-stem**: + Guitar
|
311 |
+
|
312 |
+
### π‘ Tips
|
313 |
+
- Higher quality input = better separation
|
314 |
+
- Processing time varies by model and file length
|
315 |
+
- Results will appear below after processing
|
316 |
+
""")
|
317 |
+
|
318 |
+
# Status display
|
319 |
+
status_display = gr.Textbox(
|
320 |
+
label="Status",
|
321 |
+
interactive=False,
|
322 |
+
visible=True
|
323 |
+
)
|
324 |
+
|
325 |
+
# Dynamic audio outputs
|
326 |
+
stems_state = gr.State([])
|
327 |
+
audio_outputs = gr.Column(visible=False)
|
328 |
+
|
329 |
+
def process_and_display(audio_file, model_choice):
|
330 |
+
if not audio_file:
|
331 |
+
return [], "β Please upload an audio file", gr.Column(visible=False)
|
332 |
+
|
333 |
+
# Process the audio
|
334 |
+
stems, message = separate_stems_ui(audio_file, model_choice)
|
335 |
+
|
336 |
+
# Create audio components
|
337 |
+
if stems:
|
338 |
+
with gr.Column() as output_col:
|
339 |
+
gr.Markdown(f"### πΆ Separated Stems ({len(stems)} files)")
|
340 |
+
for i, stem_path in enumerate(stems):
|
341 |
+
stem_name = Path(stem_path).stem.replace("_", " ").title()
|
342 |
+
gr.Audio(
|
343 |
+
value=stem_path,
|
344 |
+
label=f"π΅ {stem_name}",
|
345 |
+
show_download_button=True,
|
346 |
+
interactive=False
|
347 |
+
)
|
348 |
+
return stems, message, gr.Column(visible=True)
|
349 |
+
else:
|
350 |
+
return [], message, gr.Column(visible=False)
|
351 |
+
|
352 |
+
separate_btn.click(
|
353 |
+
fn=process_and_display,
|
354 |
+
inputs=[audio_input, model_dropdown],
|
355 |
+
outputs=[stems_state, status_display, audio_outputs],
|
356 |
+
show_progress=True
|
357 |
+
)
|
358 |
+
|
359 |
+
# Dependency check display
|
360 |
+
with gr.Accordion("π§ System Status", open=False):
|
361 |
+
def check_system():
|
362 |
+
deps = separator.check_dependencies()
|
363 |
+
status_text = "**Dependency Status:**\n"
|
364 |
+
for tool, available in deps.items():
|
365 |
+
status = "β
Available" if available else "β Not installed"
|
366 |
+
status_text += f"- {tool}: {status}\n"
|
367 |
+
|
368 |
+
if not any(deps.values()):
|
369 |
+
status_text += "\n**Installation Instructions:**\n"
|
370 |
+
status_text += "```bash\n"
|
371 |
+
status_text += "# Install Demucs (recommended)\n"
|
372 |
+
status_text += "pip install demucs\n\n"
|
373 |
+
status_text += "# Install Spleeter (alternative)\n"
|
374 |
+
status_text += "pip install spleeter tensorflow\n"
|
375 |
+
status_text += "```"
|
376 |
+
|
377 |
+
return status_text
|
378 |
+
|
379 |
+
system_status = gr.Markdown(value=check_system())
|
380 |
+
|
381 |
+
gr.Button("π Refresh Status").click(
|
382 |
+
fn=check_system,
|
383 |
+
outputs=system_status
|
384 |
+
)
|
385 |
+
|
386 |
+
return demo
|
387 |
+
|
388 |
+
# Launch the interface
|
389 |
+
if __name__ == "__main__":
|
390 |
+
demo = create_interface()
|
391 |
+
demo.launch(
|
392 |
+
server_name="0.0.0.0",
|
393 |
+
server_port=7860,
|
394 |
+
share=False,
|
395 |
+
show_error=True,
|
396 |
+
debug=True
|
397 |
+
)
|