# mdx_core.py
import torch
import numpy as np
import onnxruntime as ort
import hashlib
import queue
import threading
from tqdm import tqdm


class MDXModel:
    """Spectrogram front-end for an MDX separation model.

    Converts fixed-length waveform chunks into the truncated, real-valued
    spectrogram layout fed to the network (``stft``) and back (``istft``).

    Args:
        device: torch device on which the window and padding tensors live.
        dim_f: number of (lowest) frequency bins kept out of ``n_fft // 2 + 1``.
        dim_t: number of STFT frames per chunk.
        n_fft: FFT size.
        hop: STFT hop length in samples.
        stem_name: stored only; not read anywhere in this file.
        compensation: stored only; not read anywhere in this file.
    """

    def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
        self.dim_f = dim_f
        self.dim_t = dim_t
        # 4 = 2 input rows x (real, imag) -- see the reshape in stft().
        self.dim_c = 4
        self.n_fft = n_fft
        self.hop = hop
        self.stem_name = stem_name
        self.compensation = compensation

        self.n_bins = self.n_fft // 2 + 1
        # With center=True, hop * (dim_t - 1) samples yield exactly dim_t STFT frames.
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
        # Zero rows appended in istft() to restore the bins that stft() drops.
        self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins - self.dim_f, self.dim_t]).to(device)

    def stft(self, x):
        """Return the truncated spectrogram of ``x``.

        ``x`` is flattened into rows of ``chunk_size`` samples; every two
        consecutive rows (presumably the two stereo channels) are packed into
        4 channels of (real, imag) pairs.

        Returns:
            Tensor of shape ``(-1, 4, dim_f, dim_t)``.
        """
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
        x = torch.view_as_real(x)  # (rows, n_bins, dim_t, 2)
        x = x.permute([0, 3, 1, 2])  # (rows, 2, n_bins, dim_t)
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 4, self.n_bins, self.dim_t])
        return x[:, :, :self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert ``stft``.

        Args:
            x: tensor of shape ``(batch, 4, dim_f, dim_t)``.
            freq_pad: rows appended on the frequency axis to get back to
                ``n_bins``; defaults to zeros.

        Returns:
            Tensor of shape ``(-1, 2, chunk_size)``.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1]).contiguous()  # (rows, n_bins, dim_t, 2)
        x = torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, 2, self.chunk_size])


class MDX:
    """Run an MDX ONNX model over a whole waveform, optionally split across threads."""

    DEFAULT_SR = 44100
    # 0 means "no splitting": segment() treats chunk_size <= 0 as the whole wave.
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    # Overlap (in samples) kept on each side of a split segment.
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR

    def __init__(self, model_path: str, params: MDXModel, processor=0):
        """Load the ONNX model and run one warm-up inference.

        Args:
            model_path: path to the ONNX model file.
            params: spectrogram configuration matching the model.
            processor: CUDA device index, or a negative value for CPU.
        """
        on_gpu = processor >= 0
        self.device = torch.device(f"cuda:{processor}" if on_gpu else "cpu")
        self.provider = ["CUDAExecutionProvider"] if on_gpu else ["CPUExecutionProvider"]
        self.model = params
        self.ort = ort.InferenceSession(model_path, providers=self.provider)
        # Warm-up run with a random input of the shape stft() produces.
        self.ort.run(None, {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
        # Inference callable: spectrogram tensor -> first model output as a NumPy array.
        self.process = lambda spec: self.ort.run(None, {"input": spec.cpu().numpy()})[0]
        self.prog = None

    @staticmethod
    def get_hash(model_path):
        """Return an MD5 hex digest identifying the file at ``model_path``.

        Only the last 10000 KiB are hashed; a file too small to seek that far
        back from its end is hashed in full.
        """
        try:
            with open(model_path, "rb") as f:
                f.seek(-10000 * 1024, 2)
                return hashlib.md5(f.read()).hexdigest()
        except OSError:
            # Seeking before the start of the file fails; hash the whole file instead.
            with open(model_path, "rb") as f:
                return hashlib.md5(f.read()).hexdigest()

    @staticmethod
    def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
        """Split a wave into overlapping segments, or stitch such segments back.

        Split (``combine=False``): ``wave`` is a ``(channels, n)`` array. Segment
        ``i`` covers ``[i*chunk_size - margin, (i+1)*chunk_size + margin)``,
        clipped to the wave; the first has no left margin and the last ends at
        ``n``. ``chunk_size <= 0`` or larger than ``n`` means a single segment.
        ``margin_size`` is clamped to ``chunk_size``.

        Combine (``combine=True``): ``wave`` is the list produced by a split
        (possibly processed, same lengths). Margins are cut off and the cores are
        concatenated. Pass the same ``chunk_size``/``margin_size`` as for the split.

        Returns:
            Split: list of copied segment arrays. Combine: the concatenated
            array, or ``None`` if ``wave`` is empty.
        """
        if combine:
            # Apply the same clamp as the split path, otherwise short inputs are
            # trimmed by the unclamped margin and come back with the wrong length.
            if 0 < chunk_size < margin_size:
                margin_size = chunk_size
            last = len(wave) - 1
            parts = []
            for segment_count, segment in enumerate(wave):
                start = 0 if segment_count == 0 else margin_size
                end = None if (segment_count == last or margin_size == 0) else -margin_size
                parts.append(segment[:, start:end])
            # Concatenate once instead of re-copying the growing result per segment.
            return np.concatenate(parts, axis=-1) if parts else None

        sample_count = wave.shape[-1]
        if chunk_size <= 0 or chunk_size > sample_count:
            chunk_size = sample_count
        if margin_size > chunk_size:
            margin_size = chunk_size
        if sample_count == 0:
            # range() below would get a zero step; an empty wave is one empty segment.
            return [wave.copy()]

        processed_wave = []
        for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
            margin = 0 if segment_count == 0 else margin_size
            end = min(skip + chunk_size + margin_size, sample_count)
            processed_wave.append(wave[:, skip - margin:end].copy())
            if end == sample_count:
                break
        return processed_wave

    def pad_wave(self, wave):
        """Zero-pad a ``(2, n)`` wave and cut it into overlapping model-sized chunks.

        Chunks are ``chunk_size`` samples long and start every
        ``gen_size = chunk_size - 2 * trim`` samples, so once ``trim`` samples
        are dropped from both ends of every chunk the kept parts tile the
        padded signal exactly.

        Returns:
            ``(mix_waves, pad, trim)``:
            - ``mix_waves``: float32 tensor of shape ``(n_chunks, 2, chunk_size)``
              on ``self.device``.
            - ``pad``: number of zeros appended after the signal, in ``1..gen_size``.
              It is a full ``gen_size`` when ``n`` is already a multiple, so the
              caller's ``[:, :-pad]`` slice never degenerates to ``[:, :-0]``.
            - ``trim``: ``n_fft // 2``.
        """
        n_sample = wave.shape[1]
        trim = self.model.n_fft // 2
        chunk_size = self.model.chunk_size
        gen_size = chunk_size - 2 * trim
        pad = gen_size - n_sample % gen_size
        wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad + trim))), 1)
        chunks = np.stack([wave_p[:, i:i + chunk_size] for i in range(0, n_sample + pad, gen_size)])
        # One host-to-device transfer instead of one per chunk.
        mix_waves = torch.tensor(chunks, dtype=torch.float32).to(self.device)
        return mix_waves, pad, trim

    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
        """Worker: run the model over every chunk and put ``{_id: wave}`` on ``q``.

        ``wave`` is the ``(2, n)`` result with the per-chunk ``trim`` overlap
        and the trailing ``pad`` removed. Nothing is put on ``q`` if this raises.
        """
        with torch.no_grad():
            pw = []
            for mix_wave in mix_waves.split(1):
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(processed_spec.to(self.device))
                # Drop the overlap on both sides, then flatten to (2, gen_size).
                result = processed_wav[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).cpu().numpy()
                pw.append(result)
        q.put({_id: np.concatenate(pw, axis=-1)[:, :-pad]})

    def process_wave(self, wave: np.ndarray, mt_threads=1):
        """Run the model over ``wave`` and return the processed waveform.

        The wave is split into roughly ``mt_threads`` overlapping segments
        (see ``segment``). Each segment is processed in its own thread, and the
        results are stitched back together with the overlaps removed.

        Args:
            wave: ``(2, n)`` array.
            mt_threads: requested number of worker threads; values below 1 act as 1.

        Returns:
            ``(2, n)`` array.

        Raises:
            RuntimeError: if a worker thread failed and produced no output.
        """
        mt_threads = max(1, int(mt_threads))
        chunk = wave.shape[-1] // mt_threads
        waves = self.segment(wave, False, chunk)
        # Pad every segment up front so the progress bar gets an exact total.
        padded = [self.pad_wave(batch) for batch in waves]
        q = queue.Queue()
        threads = [
            threading.Thread(target=self._process_wave, args=(mix_waves, trim, pad, q, c))
            for c, (mix_waves, pad, trim) in enumerate(padded)
        ]
        self.prog = tqdm(total=sum(len(mix_waves) for mix_waves, _, _ in padded))
        try:
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
        finally:
            self.prog.close()

        results = {}
        for _ in threads:
            try:
                # All workers have exited, so a missing item means one of them
                # raised; blocking get() here would hang forever.
                results.update(q.get_nowait())
            except queue.Empty:
                raise RuntimeError("MDX worker thread failed; no output was produced") from None
        ordered = [results[c] for c in range(len(threads))]
        return self.segment(ordered, True, chunk)