Spaces:
Sleeping
Sleeping
# inference.py | |
import os | |
import gc | |
import json | |
import shlex | |
import sys | |
import torch | |
import librosa | |
import numpy as np | |
import subprocess | |
import soundfile as sf | |
import hashlib | |
import random | |
import time | |
import traceback | |
import onnxruntime as ort | |
from utils import logger, remove_directory_contents, create_directories | |
from mdx_core import MDX, MDXModel | |
from effects import add_vocal_effects, add_instrumental_effects | |
# Maps a primary stem name to the file-name label used for its inverted
# (residual) counterpart when no explicit ``invert_suffix`` is given.
stem_naming = dict(
    Vocals="Instrumental",
    Other="Instruments",
    Instrumental="Vocals",
    Drums="Drumless",
    Bass="Bassless",
)
def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False,
            suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2, device_base="cuda"):
    """Separate ``filename`` into a primary stem and its inversion with an MDX model.

    Args:
        model_params: Mapping from the model hash returned by ``MDX.get_hash`` to
            a parameter dict with keys ``mdx_dim_f_set``, ``mdx_dim_t_set``,
            ``mdx_n_fft_scale_set``, ``primary_stem`` and ``compensate``.
        output_dir: Directory the output WAV files are written to.
        model_path: Path to the MDX model file.
        filename: Input audio file (loaded at 44.1 kHz, channels preserved).
        exclude_main: If True, the primary stem is not written.
        exclude_inversion: If True, the inverted (residual) stem is not written.
        suffix: File-name label for the primary stem; defaults to the model's
            stem name.
        invert_suffix: File-name label for the inverted stem; defaults to the
            ``stem_naming`` counterpart of the primary label, or ``"<label>_diff"``.
        denoise: If True, the model is run on both the normalised wave and its
            negation, and the two outputs (the second negated back) are averaged.
        keep_orig: If False, ``filename`` is deleted afterwards.
        m_threads: Overridden inside the function (chosen from the device / VRAM);
            kept for interface compatibility.
        device_base: ``"cuda"`` runs on ``cuda:0``; anything else runs on CPU.

    Returns:
        ``(main_filepath, invert_filepath)``; an entry is None when excluded.

    Raises:
        ValueError: If ``model_params`` has no entry for the model's hash.
    """
    device = torch.device("cuda:0" if device_base == "cuda" else "cpu")
    processor_num = 0 if device_base == "cuda" else -1

    # The caller's m_threads is replaced by a value derived from available VRAM.
    if device_base == "cuda":
        vram_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
        logger.info(f"threads: {m_threads} vram: {vram_gb}")
    else:
        m_threads = 1

    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    if mp is None:
        raise ValueError(f"No MDX parameters found for model {model_path!r} (hash {model_hash})")
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )
    mdx_sess = MDX(model_path, model, processor=processor_num)

    wave, sr = librosa.load(filename, mono=False, sr=44100)
    # Normalise the model input to unit peak. A silent file (peak 0) is left
    # unscaled instead of being divided by zero.
    peak = max(np.max(wave), abs(np.min(wave)))
    if peak == 0:
        peak = 1.0
    wave /= peak
    if denoise:
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # Return the separated stem to the input's original level.
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
        stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        invert_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        # ``wave`` is still peak-normalised while ``wave_processed`` is back at the
        # original level, so rescale ``wave`` before subtracting; otherwise the
        # residual contains a leaked copy of the primary stem.
        sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T * peak, sr)

    if not keep_orig:
        os.remove(filename)
    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
def run_mdx_beta(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False,
                 suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=1, device_base=""):
    """CPU-only variant of MDX separation that picks the thread count from the audio duration.

    Args:
        model_params: Mapping from the model hash returned by ``MDX.get_hash`` to
            a parameter dict with keys ``mdx_dim_f_set``, ``mdx_dim_t_set``,
            ``mdx_n_fft_scale_set``, ``primary_stem`` and ``compensate``.
        output_dir: Directory the output WAV files are written to.
        model_path: Path to the MDX model file.
        filename: Input audio file (loaded at 44.1 kHz, channels preserved).
        exclude_main: If True, the primary stem is not written.
        exclude_inversion: If True, the inverted (residual) stem is not written.
        suffix: File-name label for the primary stem; defaults to the model's
            stem name.
        invert_suffix: File-name label for the inverted stem; defaults to the
            ``stem_naming`` counterpart of the primary label, or ``"<label>_diff"``.
        denoise: If True, the model is run on both the normalised wave and its
            negation, and the two outputs (the second negated back) are averaged.
        keep_orig: If False, ``filename`` is deleted afterwards.
        m_threads: Used for inputs shorter than 60 s; 60-120 s uses 8 threads and
            longer inputs use 16.
        device_base: Unused; this variant always runs on CPU.

    Returns:
        ``(main_filepath, invert_filepath)``; an entry is None when excluded.

    Raises:
        ValueError: If ``model_params`` has no entry for the model's hash.
    """
    try:
        duration = librosa.get_duration(path=filename)
    except TypeError:
        # Older librosa (< 0.10) only accepts the ``filename`` keyword.
        duration = librosa.get_duration(filename=filename)
    if 60 <= duration <= 120:
        m_threads = 8
    elif duration > 120:
        m_threads = 16
    logger.info(f"threads: {m_threads}")

    device = torch.device("cpu")
    processor_num = -1
    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    if mp is None:
        raise ValueError(f"No MDX parameters found for model {model_path!r} (hash {model_hash})")
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )
    mdx_sess = MDX(model_path, model, processor=processor_num)

    wave, sr = librosa.load(filename, mono=False, sr=44100)
    # Normalise the model input to unit peak. A silent file (peak 0) is left
    # unscaled instead of being divided by zero.
    peak = max(np.max(wave), abs(np.min(wave)))
    if peak == 0:
        peak = 1.0
    wave /= peak
    if denoise:
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # Return the separated stem to the input's original level.
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
        stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        invert_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        # ``wave`` is still peak-normalised while ``wave_processed`` is back at the
        # original level, so rescale ``wave`` before subtracting.
        sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T * peak, sr)

    if not keep_orig:
        os.remove(filename)
    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
def convert_to_stereo_and_wav(audio_path, output_dir):
    """Return a path to a multi-channel WAV version of ``audio_path``.

    If the input decodes to a single channel, or its name does not end in
    ``.wav`` (case-insensitive), it is re-encoded with ffmpeg into a 2-channel
    WAV named ``<stem>_stereo.wav`` inside ``output_dir``. Otherwise
    ``audio_path`` is returned unchanged.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    wave, _ = librosa.load(audio_path, mono=False, sr=44100)
    # With mono=False librosa yields a 1-D array only for single-channel input.
    is_mono = wave.ndim == 1
    del wave
    if not is_mono and audio_path.lower().endswith(".wav"):
        return audio_path

    stem = os.path.splitext(os.path.basename(audio_path))[0]
    stereo_path = os.path.join(output_dir, f"{stem}_stereo.wav")
    # Pass an argument list rather than shlex-splitting a formatted string:
    # the string form breaks on paths that contain double quotes.
    command = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-i", audio_path,
        "-ac", "2",
        "-f", "wav", stereo_path,
    ]
    subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    return stereo_path
def get_hash(filepath):
    """Return the first 18 hex characters of the BLAKE2b digest of a file.

    The file is consumed in 8192-byte chunks rather than read whole.
    """
    digest = hashlib.blake2b()
    with open(filepath, "rb") as handle:
        for block in iter(lambda: handle.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()[:18]
def random_sleep(min_seconds=5.2, max_seconds=7.9):
    """Sleep for a random duration drawn uniformly from [min_seconds, max_seconds].

    The duration is rounded to one decimal place. The defaults reproduce the
    original fixed 5.2-7.9 s range, so existing zero-argument calls are unchanged.

    Returns:
        The number of seconds slept.
    """
    delay = round(random.uniform(min_seconds, max_seconds), 1)
    time.sleep(delay)
    return delay