# mdx_core.py
import torch
import numpy as np
import onnxruntime as ort
import hashlib
import queue
import threading
from tqdm import tqdm
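

# MDXModel describes the spectrogram geometry an MDX-Net ONNX checkpoint
# expects: dim_f frequency bins by dim_t frames, with the stereo signal's
# real and imaginary STFT components packed into dim_c = 4 input channels.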
class MDXModel:
    def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.dim_c = 4
        self.n_fft = n_fft
        self.hop = hop
        self.stem_name = stem_name
        self.compensation = compensation
        self.n_bins = self.n_fft // 2 + 1
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
        self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
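
    # Pack a stereo chunk into the network input layout (batch, 4, dim_f, dim_t):
    # the two channels' real/imag STFT parts become 4 channels, and the spectrum
    # is cropped to the dim_f bins the model operates on.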
    def stft(self, x):
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window,
                       center=True, return_complex=True)
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 4, self.n_bins, self.dim_t])
        return x[:, :, :self.dim_f]
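
    # Inverse of stft(): restore the cropped high-frequency bins with zeros
    # (freq_pad), unpack the 4 channels back into complex stereo spectra, and
    # resynthesize time-domain audio of exactly chunk_size samples.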
    def istft(self, x, freq_pad=None):
        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, 2, self.chunk_size])
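

# MDX wraps an ONNX Runtime session around an MDXModel. DEFAULT_CHUNK_SIZE of
# 0 means "no chunking": segment() falls back to the full wave length; chunks
# overlap by DEFAULT_MARGIN_SIZE (one second) so seams can be trimmed away.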
class MDX:
    DEFAULT_SR = 44100
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR

    def __init__(self, model_path: str, params: MDXModel, processor=0):
        # processor >= 0 selects that CUDA device; any negative value means CPU.
        self.device = torch.device(f"cuda:{processor}" if processor >= 0 else "cpu")
        self.provider = ["CUDAExecutionProvider"] if processor >= 0 else ["CPUExecutionProvider"]
        self.model = params
        self.ort = ort.InferenceSession(model_path, providers=self.provider)
        # Warm-up run so session initialization does not slow the first real chunk.
        self.ort.run(None, {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
        self.process = lambda spec: self.ort.run(None, {"input": spec.cpu().numpy()})[0]
        self.prog = None
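
    # Identify a model file by the MD5 of (at most) its last ~10 MB, which is
    # much faster than hashing a multi-hundred-megabyte checkpoint in full.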
    @staticmethod
    def get_hash(model_path):
        try:
            with open(model_path, "rb") as f:
                f.seek(-10000 * 1024, 2)
                return hashlib.md5(f.read()).hexdigest()
        except OSError:
            # Seeking past the start fails for small files; hash them whole.
            with open(model_path, "rb") as f:
                return hashlib.md5(f.read()).hexdigest()
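
    # With combine=False, split a (2, n_samples) wave into chunk_size pieces,
    # each extended by margin_size samples of overlap; with combine=True, undo
    # the split by trimming those margins and concatenating the pieces.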
    @staticmethod
    def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
        if combine:
            processed_wave = None
            for segment_count, segment in enumerate(wave):
                start = 0 if segment_count == 0 else margin_size
                end = None if segment_count == len(wave) - 1 else -margin_size
                if margin_size == 0:
                    end = None
                part = segment[:, start:end]
                processed_wave = part if processed_wave is None else np.concatenate((processed_wave, part), axis=-1)
        else:
            processed_wave = []
            sample_count = wave.shape[-1]
            if chunk_size <= 0 or chunk_size > sample_count:
                chunk_size = sample_count
            if margin_size > chunk_size:
                margin_size = chunk_size
            for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
                margin = 0 if segment_count == 0 else margin_size
                end = min(skip + chunk_size + margin_size, sample_count)
                start = skip - margin
                processed_wave.append(wave[:, start:end].copy())
                if end == sample_count:
                    break
        return processed_wave
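
    # Zero-pad a (2, n_samples) wave so it divides evenly into overlapping
    # windows of chunk_size samples (stepping by chunk_size - n_fft, i.e. with
    # n_fft // 2 samples of STFT edge padding on each side), and stack the
    # windows into one tensor for batched inference.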
    def pad_wave(self, wave):
        n_sample = wave.shape[1]
        trim = self.model.n_fft // 2
        gen_size = self.model.chunk_size - 2 * trim
        pad = gen_size - n_sample % gen_size
        wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad)), np.zeros((2, trim))), 1)
        mix_waves = [torch.tensor(wave_p[:, i:i + self.model.chunk_size], dtype=torch.float32).to(self.device)
                     for i in range(0, n_sample + pad, gen_size)]
        return torch.stack(mix_waves), pad, trim
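
    # Worker body: run STFT -> ONNX model -> iSTFT over each window, trim the
    # STFT edge padding, and put the reassembled stem on the queue keyed by
    # the segment id so the caller can restore chunk order.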
    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
        mix_waves = mix_waves.split(1)
        with torch.no_grad():
            pw = []
            for mix_wave in mix_waves:
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(processed_spec.to(self.device))
                result = processed_wav[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).cpu().numpy()
                pw.append(result)
        q.put({_id: np.concatenate(pw, axis=-1)[:, :-pad]})
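
    # Public entry point: split the wave across mt_threads workers, process
    # each piece in its own thread, then stitch the ordered results back into
    # a single (2, n_samples) array.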
    def process_wave(self, wave: np.ndarray, mt_threads=1):
        self.prog = tqdm(total=0)
        chunk = wave.shape[-1] // mt_threads
        waves = self.segment(wave, False, chunk)
        q = queue.Queue()
        threads = []
        for c, batch in enumerate(waves):
            mix_waves, pad, trim = self.pad_wave(batch)
            self.prog.total = len(mix_waves) * mt_threads
            thread = threading.Thread(target=self._process_wave, args=(mix_waves, trim, pad, q, c))
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        self.prog.close()
        processed_batches = [q.get() for _ in range(len(waves))]
        processed_batches.sort(key=lambda d: list(d.keys())[0])
        return self.segment([list(batch.values())[0] for batch in processed_batches], True, chunk)
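

# Minimal usage sketch (not part of the original file). The checkpoint path,
# STFT parameters, and audio files below are hypothetical placeholders; the
# real values depend on the specific MDX-Net checkpoint, and soundfile is
# assumed to be available only for audio I/O.
if __name__ == "__main__":
    import soundfile as sf

    params = MDXModel(device=torch.device("cpu"), dim_f=3072, dim_t=256, n_fft=7680)
    mdx = MDX("mdx_net_vocals.onnx", params, processor=-1)  # hypothetical path, CPU inference

    wave, sr = sf.read("mix.wav", dtype="float32")  # hypothetical stereo input
    stem = mdx.process_wave(wave.T, mt_threads=2)   # process_wave expects (2, n_samples)
    sf.write("stem.wav", stem.T, sr)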