masszhou committed
Commit 17d9938 · 1 Parent(s): 2e4d768

Add application file

Files changed (4)
  1. app.py +647 -0
  2. mdx_models/model_data.json +50 -0
  3. pyproject.toml +30 -0
  4. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,647 @@
+ import os
+ # os.system("pip install ./ort_nightly_gpu-1.17.0.dev20240118002-cp310-cp310-manylinux_2_28_x86_64.whl")
+ os.system("pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/")
+ import gc
+ import hashlib
+ import queue
+ import threading
+ import json
+ import shlex
+ import sys
+ import subprocess
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ import torch
+ from tqdm import tqdm
+ import random
+ import spaces
+ import onnxruntime as ort
+ import warnings
+ import gradio as gr
+ import logging
+ import time
+ import traceback
+ import yt_dlp
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download
+ from typing import Dict, Tuple
+
+
+ MODEL_ID = "masszhou/mdxnet"
+ # Download the ONNX separation models from the Hugging Face Hub
+ MODELS_PATH = {
+     "bgm": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR-MDX-NET-Inst_HQ_3.onnx")),
+     "basic_vocal": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR-MDX-NET-Voc_FT.onnx")),
+     "main_vocal": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR_MDXNET_KARA_2.onnx")),
+ }
+
+
+ # Maps a model's primary stem to the complementary stem produced by the inverted output
+ STEM_NAMING = {
+     "Vocals": "Instrumental",
+     "Other": "Instruments",
+     "Instrumental": "Vocals",
+     "Drums": "Drumless",
+     "Bass": "Bassless",
+ }
+
+
+ class MDXModel:
+     def __init__(
+         self,
+         device,
+         dim_f,
+         dim_t,
+         n_fft,
+         hop=1024,
+         stem_name=None,
+         compensation=1.000,
+     ):
+         self.dim_f = dim_f
+         self.dim_t = dim_t
+         self.dim_c = 4
+         self.n_fft = n_fft
+         self.hop = hop
+         self.stem_name = stem_name
+         self.compensation = compensation
+
+         self.n_bins = self.n_fft // 2 + 1
+         self.chunk_size = hop * (self.dim_t - 1)
+         self.window = torch.hann_window(
+             window_length=self.n_fft, periodic=True
+         ).to(device)
+
+         out_c = self.dim_c
+
+         self.freq_pad = torch.zeros(
+             [1, out_c, self.n_bins - self.dim_f, self.dim_t]
+         ).to(device)
+
+     def stft(self, x):
+         x = x.reshape([-1, self.chunk_size])
+         x = torch.stft(
+             x,
+             n_fft=self.n_fft,
+             hop_length=self.hop,
+             window=self.window,
+             center=True,
+             return_complex=True,
+         )
+         x = torch.view_as_real(x)
+         x = x.permute([0, 3, 1, 2])
+         x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+             [-1, 4, self.n_bins, self.dim_t]
+         )
+         return x[:, :, : self.dim_f]
+
+     def istft(self, x, freq_pad=None):
+         freq_pad = (
+             self.freq_pad.repeat([x.shape[0], 1, 1, 1])
+             if freq_pad is None
+             else freq_pad
+         )
+         x = torch.cat([x, freq_pad], -2)
+         # c = 4*2 if self.target_name=='*' else 2
+         x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+             [-1, 2, self.n_bins, self.dim_t]
+         )
+         x = x.permute([0, 2, 3, 1])
+         x = x.contiguous()
+         x = torch.view_as_complex(x)
+         x = torch.istft(
+             x,
+             n_fft=self.n_fft,
+             hop_length=self.hop,
+             window=self.window,
+             center=True,
+         )
+         return x.reshape([-1, 2, self.chunk_size])
+
+
+ class MDX:
+     DEFAULT_SR = 44100
+     # Unit: seconds
+     DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
+     DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
+
+     def __init__(
+         self, model_path: str, params: MDXModel, processor=0
+     ):
+         # Set the device and the provider (CPU or CUDA)
+         self.device = (
+             torch.device(f"cuda:{processor}")
+             if processor >= 0
+             else torch.device("cpu")
+         )
+         self.provider = (
+             ["CUDAExecutionProvider"]
+             if processor >= 0
+             else ["CPUExecutionProvider"]
+         )
+
+         self.model = params
+
+         # Load the ONNX model using ONNX Runtime
+         self.ort = ort.InferenceSession(model_path, providers=self.provider)
+         # Preload the model for faster performance
+         self.ort.run(
+             None,
+             {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
+         )
+         self.process = lambda spec: self.ort.run(
+             None, {"input": spec.cpu().numpy()}
+         )[0]
+
+         self.prog = None
+
+     @staticmethod
+     def get_hash(model_path):
+         try:
+             with open(model_path, "rb") as f:
+                 f.seek(-10000 * 1024, 2)
+                 model_hash = hashlib.md5(f.read()).hexdigest()
+         except:  # noqa
+             model_hash = hashlib.md5(open(model_path, "rb").read()).hexdigest()
+
+         return model_hash
+
+     @staticmethod
+     def segment(
+         wave,
+         combine=True,
+         chunk_size=DEFAULT_CHUNK_SIZE,
+         margin_size=DEFAULT_MARGIN_SIZE,
+     ):
+         """
+         Segment or join segmented wave array
+         Args:
+             wave: (np.array) Wave array to be segmented or joined
+             combine: (bool) If True, combines segmented wave array.
+                 If False, segments wave array.
+             chunk_size: (int) Size of each segment (in samples)
+             margin_size: (int) Size of margin between segments (in samples)
+         Returns:
+             numpy array: Segmented or joined wave array
+         """
+
+         if combine:
+             # Initializing as None instead of [] for later numpy array concatenation
+             processed_wave = None
+             for segment_count, segment in enumerate(wave):
+                 start = 0 if segment_count == 0 else margin_size
+                 end = None if segment_count == len(wave) - 1 else -margin_size
+                 if margin_size == 0:
+                     end = None
+                 if processed_wave is None:  # Create array for first segment
+                     processed_wave = segment[:, start:end]
+                 else:  # Concatenate to existing array for subsequent segments
+                     processed_wave = np.concatenate(
+                         (processed_wave, segment[:, start:end]), axis=-1
+                     )
+
+         else:
+             processed_wave = []
+             sample_count = wave.shape[-1]
+
+             if chunk_size <= 0 or chunk_size > sample_count:
+                 chunk_size = sample_count
+
+             if margin_size > chunk_size:
+                 margin_size = chunk_size
+
+             for segment_count, skip in enumerate(
+                 range(0, sample_count, chunk_size)
+             ):
+                 margin = 0 if segment_count == 0 else margin_size
+                 end = min(skip + chunk_size + margin_size, sample_count)
+                 start = skip - margin
+
+                 cut = wave[:, start:end].copy()
+                 processed_wave.append(cut)
+
+                 if end == sample_count:
+                     break
+
+         return processed_wave
+
+     def pad_wave(self, wave):
+         """
+         Pad the wave array to match the required chunk size
+         Args:
+             wave: (np.array) Wave array to be padded
+         Returns:
+             tuple: (padded_wave, pad, trim)
+                 - padded_wave: Padded wave array
+                 - pad: Number of samples that were padded
+                 - trim: Number of samples that were trimmed
+         """
+         n_sample = wave.shape[1]
+         trim = self.model.n_fft // 2
+         gen_size = self.model.chunk_size - 2 * trim
+         pad = gen_size - n_sample % gen_size
+
+         # Padded wave
+         wave_p = np.concatenate(
+             (
+                 np.zeros((2, trim)),
+                 wave,
+                 np.zeros((2, pad)),
+                 np.zeros((2, trim)),
+             ),
+             1,
+         )
+
+         mix_waves = []
+         for i in range(0, n_sample + pad, gen_size):
+             waves = np.array(wave_p[:, i:i + self.model.chunk_size])
+             mix_waves.append(waves)
+
+         mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
+             self.device
+         )
+
+         return mix_waves, pad, trim
+
+     def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
+         """
+         Process each wave segment in a multi-threaded environment
+         Args:
+             mix_waves: (torch.Tensor) Wave segments to be processed
+             trim: (int) Number of samples trimmed during padding
+             pad: (int) Number of samples padded during padding
+             q: (queue.Queue) Queue to hold the processed wave segments
+             _id: (int) Identifier of the processed wave segment
+         Returns:
+             numpy array: Processed wave segment
+         """
+         mix_waves = mix_waves.split(1)
+         with torch.no_grad():
+             pw = []
+             for mix_wave in mix_waves:
+                 self.prog.update()
+                 spec = self.model.stft(mix_wave)
+                 processed_spec = torch.tensor(self.process(spec))
+                 processed_wav = self.model.istft(
+                     processed_spec.to(self.device)
+                 )
+                 processed_wav = (
+                     processed_wav[:, :, trim:-trim]
+                     .transpose(0, 1)
+                     .reshape(2, -1)
+                     .cpu()
+                     .numpy()
+                 )
+                 pw.append(processed_wav)
+         processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
+         q.put({_id: processed_signal})
+         return processed_signal
+
+     def process_wave(self, wave: np.array, mt_threads=1):
+         """
+         Process the wave array in a multi-threaded environment
+         Args:
+             wave: (np.array) Wave array to be processed
+             mt_threads: (int) Number of threads to be used for processing
+         Returns:
+             numpy array: Processed wave array
+         """
+         self.prog = tqdm(total=0)
+         chunk = wave.shape[-1] // mt_threads
+         waves = self.segment(wave, False, chunk)
+
+         # Create a queue to hold the processed wave segments
+         q = queue.Queue()
+         threads = []
+         for c, batch in enumerate(waves):
+             mix_waves, pad, trim = self.pad_wave(batch)
+             self.prog.total = len(mix_waves) * mt_threads
+             thread = threading.Thread(
+                 target=self._process_wave, args=(mix_waves, trim, pad, q, c)
+             )
+             thread.start()
+             threads.append(thread)
+         for thread in threads:
+             thread.join()
+         self.prog.close()
+
+         processed_batches = []
+         while not q.empty():
+             processed_batches.append(q.get())
+         processed_batches = [
+             list(wave.values())[0]
+             for wave in sorted(
+                 processed_batches, key=lambda d: list(d.keys())[0]
+             )
+         ]
+         assert len(processed_batches) == len(
+             waves
+         ), "Incomplete processed batches, please reduce batch size!"
+         return self.segment(processed_batches, True, chunk)
+
+
+ @spaces.GPU()
+ def run_mdx(
+     model_params,
+     output_dir,
+     model_path,
+     filename,
+     exclude_main=False,
+     exclude_inversion=False,
+     suffix=None,
+     invert_suffix=None,
+     denoise=False,
+     keep_orig=True,
+     m_threads=2,
+     device_base="cuda",
+ ):
+
+     if device_base == "cuda":
+         device = torch.device("cuda:0")
+         processor_num = 0
+         device_properties = torch.cuda.get_device_properties(device)
+         vram_gb = device_properties.total_memory / 1024**3
+         m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
+     else:
+         device = torch.device("cpu")
+         processor_num = -1
+         m_threads = 1
+
+     model_hash = MDX.get_hash(model_path)
+     mp = model_params.get(model_hash)
+     model = MDXModel(
+         device,
+         dim_f=mp["mdx_dim_f_set"],
+         dim_t=2 ** mp["mdx_dim_t_set"],
+         n_fft=mp["mdx_n_fft_scale_set"],
+         stem_name=mp["primary_stem"],
+         compensation=mp["compensate"],
+     )
+
+     mdx_sess = MDX(model_path, model, processor=processor_num)
+     wave, sr = librosa.load(filename, mono=False, sr=44100)
+     # normalizing input wave gives better output
+     peak = max(np.max(wave), abs(np.min(wave)))
+     wave /= peak
+     if denoise:
+         wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
+             mdx_sess.process_wave(wave, m_threads)
+         )
+         wave_processed *= 0.5
+     else:
+         wave_processed = mdx_sess.process_wave(wave, m_threads)
+     # return to previous peak
+     wave_processed *= peak
+     stem_name = model.stem_name if suffix is None else suffix
+
+     main_filepath = None
+     if not exclude_main:
+         main_filepath = os.path.join(
+             output_dir,
+             f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
+         )
+         sf.write(main_filepath, wave_processed.T, sr)
+
+     invert_filepath = None
+     if not exclude_inversion:
+         diff_stem_name = (
+             STEM_NAMING.get(stem_name)
+             if invert_suffix is None
+             else invert_suffix
+         )
+         stem_name = (
+             f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
+         )
+         invert_filepath = os.path.join(
+             output_dir,
+             f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
+         )
+         sf.write(
+             invert_filepath,
+             (-wave_processed.T * model.compensation) + wave.T,
+             sr,
+         )
+
+     if not keep_orig:
+         os.remove(filename)
+
+     del mdx_sess, wave_processed, wave
+     gc.collect()
+     torch.cuda.empty_cache()
+     return main_filepath, invert_filepath
+
+
+ def run_mdx_beta(
+     model_params,
+     output_dir,
+     model_path,
+     filename,
+     exclude_main=False,
+     exclude_inversion=False,
+     suffix=None,
+     invert_suffix=None,
+     denoise=False,
+     keep_orig=True,
+     m_threads=2,
+     device_base="",
+ ):
+
+     m_threads = 1
+     duration = librosa.get_duration(path=filename)
+     if duration >= 60 and duration <= 120:
+         m_threads = 8
+     elif duration > 120:
+         m_threads = 16
+
+     model_hash = MDX.get_hash(model_path)
+     device = torch.device("cpu")
+     processor_num = -1
+     mp = model_params.get(model_hash)
+     model = MDXModel(
+         device,
+         dim_f=mp["mdx_dim_f_set"],
+         dim_t=2 ** mp["mdx_dim_t_set"],
+         n_fft=mp["mdx_n_fft_scale_set"],
+         stem_name=mp["primary_stem"],
+         compensation=mp["compensate"],
+     )
+
+     mdx_sess = MDX(model_path, model, processor=processor_num)
+     wave, sr = librosa.load(filename, mono=False, sr=44100)
+     # normalizing input wave gives better output
+     peak = max(np.max(wave), abs(np.min(wave)))
+     wave /= peak
+     if denoise:
+         wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
+             mdx_sess.process_wave(wave, m_threads)
+         )
+         wave_processed *= 0.5
+     else:
+         wave_processed = mdx_sess.process_wave(wave, m_threads)
+     # return to previous peak
+     wave_processed *= peak
+     stem_name = model.stem_name if suffix is None else suffix
+
+     main_filepath = None
+     if not exclude_main:
+         main_filepath = os.path.join(
+             output_dir,
+             f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
+         )
+         sf.write(main_filepath, wave_processed.T, sr)
+
+     invert_filepath = None
+     if not exclude_inversion:
+         diff_stem_name = (
+             STEM_NAMING.get(stem_name)
+             if invert_suffix is None
+             else invert_suffix
+         )
+         stem_name = (
+             f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
+         )
+         invert_filepath = os.path.join(
+             output_dir,
+             f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
+         )
+         sf.write(
+             invert_filepath,
+             (-wave_processed.T * model.compensation) + wave.T,
+             sr,
+         )
+
+     if not keep_orig:
+         os.remove(filename)
+
+     del mdx_sess, wave_processed, wave
+     gc.collect()
+     torch.cuda.empty_cache()
+     return main_filepath, invert_filepath
+
+
+ def extract_bgm(mdx_model_params: Dict,
+                 input_filename: Path,
+                 model_bgm_path: Path,
+                 output_dir: Path,
+                 device_base: str = "cuda") -> Path:
+     """
+     Extract pure background music, remove vocals
+     """
+     background_path, _ = run_mdx(model_params=mdx_model_params,
+                                  filename=input_filename,
+                                  output_dir=output_dir,
+                                  model_path=model_bgm_path,
+                                  denoise=False,
+                                  device_base=device_base,
+                                  )
+     return background_path
+
+
+ def extract_vocal(mdx_model_params: Dict,
+                   input_filename: Path,
+                   model_basic_vocal_path: Path,
+                   model_main_vocal_path: Path,
+                   output_dir: Path,
+                   main_vocals_flag: bool = False,
+                   device_base: str = "cuda") -> Path:
+     """
+     Extract vocals
+     """
+     # First use the UVR-MDX-NET-Voc_FT.onnx basic vocal separation model
+     vocals_path, _ = run_mdx(mdx_model_params,
+                              output_dir,
+                              model_basic_vocal_path,
+                              input_filename,
+                              denoise=True,
+                              device_base=device_base,
+                              )
+     # If "main_vocals_flag" is enabled, use UVR_MDXNET_KARA_2.onnx to further separate
+     # the main vocals (Main) from the backup/background vocals (Backup)
+     if main_vocals_flag:
+         time.sleep(2)
+         backup_vocals_path, main_vocals_path = run_mdx(mdx_model_params,
+                                                        output_dir,
+                                                        model_main_vocal_path,
+                                                        vocals_path,
+                                                        denoise=True,
+                                                        device_base=device_base,
+                                                        )
+         vocals_path = main_vocals_path
+
+     return vocals_path
+
+
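+ # convert_to_stereo_and_wav() is called below but is not defined or imported anywhere
+ # in this file. The following is only a minimal sketch of the assumed behavior
+ # (resample to 44.1 kHz, force two channels, write a WAV copy next to the input);
+ # the helper name's real implementation may differ.
+ def convert_to_stereo_and_wav(input_file_path: Path) -> Path:
+     wave, _ = librosa.load(str(input_file_path), mono=False, sr=44100)
+     if wave.ndim == 1:
+         # duplicate the mono channel so downstream code always sees stereo
+         wave = np.stack([wave, wave], axis=0)
+     output_path = input_file_path.parent / f"{input_file_path.stem}_stereo.wav"
+     sf.write(str(output_path), wave.T, 44100)
+     return output_path
+
+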
+ def process_uvr_task(input_file_path: Path,
+                      output_dir: Path,
+                      models_path: Dict[str, Path],
+                      main_vocals_flag: bool = False,  # if enabled, use UVR_MDXNET_KARA_2.onnx to further separate main and backup vocals
+                      ) -> Tuple[Path, Path]:
+
+     device_base = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # load mdx model definition
+     with open("./mdx_models/model_data.json") as infile:
+         mdx_model_params = json.load(infile)  # type: Dict
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+     input_file_path = convert_to_stereo_and_wav(input_file_path)  # type: Path
+
+     # 1. Extract pure background music, remove vocals
+     background_path = extract_bgm(mdx_model_params,
+                                   input_file_path,
+                                   models_path["bgm"],
+                                   output_dir,
+                                   device_base=device_base)
+
+     # 2. Separate vocals
+     # First use the UVR-MDX-NET-Voc_FT.onnx basic vocal separation model
+     vocals_path = extract_vocal(mdx_model_params,
+                                 input_file_path,
+                                 models_path["basic_vocal"],
+                                 models_path["main_vocal"],
+                                 output_dir,
+                                 main_vocals_flag=main_vocals_flag,
+                                 device_base=device_base)
+
+     return background_path, vocals_path
+
+
+ def get_model_params(model_path: Path) -> Dict:
+     """
+     Get model parameters from model path
+     """
+     with open(model_path / "model_data.json") as infile:
+         return json.load(infile)  # type: Dict
+
+
+ def inference_mdx(audio_file: str) -> Tuple[str, str]:
+     mdx_model_params = get_model_params(Path("./mdx_models"))
+     audio_file = convert_to_stereo_and_wav(Path(audio_file))  # resampling at 44100 Hz
+     device_base = "cuda" if torch.cuda.is_available() else "cpu"
+     output_dir = Path("./out/mdx")
+     os.makedirs(output_dir, exist_ok=True)
+     model_bgm_path = MODELS_PATH["bgm"]
+     background_path, vocal_path = run_mdx(
+         model_params=mdx_model_params,
+         filename=audio_file,
+         output_dir=output_dir,
+         model_path=model_bgm_path,
+         denoise=False,
+         device_base=device_base,
+     )
+
+     return str(vocal_path), str(background_path)
+
+
+ if __name__ == "__main__":
+     # zero = torch.Tensor([0]).cuda()
+     # print(f"zero.device: {zero.device}")
+
+     app = gr.Interface(
+         fn=inference_mdx,
+         inputs=gr.Audio(type="filepath", label="Input"),
+         outputs=[gr.Audio(type="filepath", label="Vocals"), gr.Audio(type="filepath", label="BGM")],
+         title="MDXNET Music Source Separation",
+         article="<p style='text-align: center'><a href='https://arxiv.org/abs/2111.12203' target='_blank'>KUIELab-MDX-Net: A Two-Stream Neural Network for Music Demixing</a> | <a href='https://github.com/kuielab/mdx-net' target='_blank'>Github Repo</a> | <a href='https://github.com/kuielab/mdx-net/blob/main/LICENSE' target='_blank'>MIT License</a></p>",
+         api_name="mdxnet_separation",
+     )
+
+     app.launch()
mdx_models/model_data.json ADDED
@@ -0,0 +1,50 @@
+ {
+     "77d07b2667ddf05b9e3175941b4454a0": {
+         "compensate": 1.021,
+         "mdx_dim_f_set": 3072,
+         "mdx_dim_t_set": 8,
+         "mdx_n_fft_scale_set": 7680,
+         "primary_stem": "Vocals",
+         "name": "UVR-MDX-NET-Voc_FT.onnx"
+     },
+     "1d64a6d2c30f709b8c9b4ce1366d96ee": {
+         "compensate": 1.035,
+         "mdx_dim_f_set": 2048,
+         "mdx_dim_t_set": 8,
+         "mdx_n_fft_scale_set": 5120,
+         "primary_stem": "Instrumental",
+         "name": "UVR_MDXNET_KARA_2.onnx"
+     },
+     "cd5b2989ad863f116c855db1dfe24e39": {
+         "compensate": 1.035,
+         "mdx_dim_f_set": 3072,
+         "mdx_dim_t_set": 9,
+         "mdx_n_fft_scale_set": 6144,
+         "primary_stem": "Other",
+         "name": "Reverb_HQ_By_FoxJoy.onnx"
+     },
+     "55657dd70583b0fedfba5f67df11d711": {
+         "compensate": 1.022,
+         "mdx_dim_f_set": 3072,
+         "mdx_dim_t_set": 8,
+         "mdx_n_fft_scale_set": 6144,
+         "primary_stem": "Instrumental",
+         "name": "UVR-MDX-NET-Inst_HQ_3.onnx"
+     },
+     "cc63408db3d80b4d85b0287d1d7c9632": {
+         "compensate": 1.033,
+         "mdx_dim_f_set": 3072,
+         "mdx_dim_t_set": 8,
+         "mdx_n_fft_scale_set": 6144,
+         "primary_stem": "Instrumental",
+         "name": "UVR-MDX-NET-Inst_HQ_2.onnx"
+     },
+     "0f2a6bc5b49d87d64728ee40e23bceb1": {
+         "compensate": 1.022,
+         "mdx_dim_f_set": 3072,
+         "mdx_dim_t_set": 8,
+         "mdx_n_fft_scale_set": 6144,
+         "primary_stem": "Instrumental",
+         "name": "UVR-MDX-NET-Inst_HQ_4.onnx"
+     }
+ }
pyproject.toml ADDED
@@ -0,0 +1,30 @@
+ [tool.poetry]
+ name = "bgmseparatorgpu"
+ version = "0.1.0"
+ description = ""
+ authors = ["Zhiliang Zhou <[email protected]>"]
+ readme = "README.md"
+ package-mode = false
+
+ [tool.poetry.dependencies]
+ python = ">=3.11,<3.13"
+ gradio = "4.42.0"
+ pydantic = "2.8.2"
+ fastapi = "0.112.2"
+ scipy = "^1.15.2"
+ numpy = "^2.2.4"
+ onnxruntime = "^1.21.0"
+ torch = "^2.6.0"
+ tqdm = "^4.67.1"
+ librosa = "^0.11.0"
+ soundfile = "^0.13.1"
+ spaces = "^0.34.2"
+ huggingface-hub = "^0.30.2"
+
+ [tool.poetry.group.dev.dependencies]
+ jupyter = "^1.1.1"
+ qtconsole = "^5.6.1"
+ pyqt5 = "^5.15.11"
+
+ [build-system]
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ soundfile
+ librosa
+ torch==2.2.0
+ pedalboard
+ yt-dlp
+ gradio==4.42.0
+ pydantic==2.8.2
+ fastapi==0.112.2
+ scipy
+ numpy