# Tune-Splitter / inference.py
# CCockrum's picture
# Update inference.py
# 9bd0a22 verified
# inference.py
import os
import gc
import json
import shlex
import sys
import torch
import librosa
import numpy as np
import subprocess
import soundfile as sf
import hashlib
import random
import time
import traceback
import onnxruntime as ort
from utils import logger, remove_directory_contents, create_directories
from mdx_core import MDX, MDXModel
from effects import add_vocal_effects, add_instrumental_effects
# Maps a stem name to the name used for its inverted (residual) output file.
# Stems without an entry fall back to "<stem>_diff" in the run_mdx* functions.
stem_naming = {
    "Vocals": "Instrumental",
    "Other": "Instruments",
    "Instrumental": "Vocals",
    "Drums": "Drumless",
    "Bass": "Bassless",
}
def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False,
            suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2, device_base="cuda"):
    """Separate a stem from ``filename`` with an MDX model and write WAV results.

    Args:
        model_params: Mapping from model hash (as returned by ``MDX.get_hash``)
            to a dict providing ``mdx_dim_f_set``, ``mdx_dim_t_set``,
            ``mdx_n_fft_scale_set``, ``primary_stem`` and ``compensate``.
        output_dir: Directory the output WAV files are written to.
        model_path: Path of the model file handed to ``MDX``.
        filename: Audio file to process; it is loaded at 44100 Hz.
        exclude_main: If true, the primary stem file is not written.
        exclude_inversion: If true, the inverted (residual) file is not written.
        suffix: Replaces the model's stem name in the primary output file name.
        invert_suffix: Replaces the looked-up name of the inverted output file.
        denoise: If true, the model is run on the signal and on its negation
            and the two passes are averaged.
        keep_orig: If false, ``filename`` is deleted after processing.
        m_threads: Accepted for interface compatibility only; the value is
            always replaced (derived from total VRAM on CUDA, 1 otherwise).
        device_base: ``"cuda"`` selects ``cuda:0``; any other value selects CPU.

    Returns:
        Tuple ``(main_filepath, invert_filepath)``; an entry is ``None`` when
        the corresponding output was excluded.

    Raises:
        ValueError: If ``model_params`` has no entry for the model's hash.
    """
    use_cuda = device_base == "cuda"
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # Processor index handed to MDX: 0 on CUDA, -1 otherwise.
    processor_num = 0 if use_cuda else -1

    if use_cuda:
        vram_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
        logger.info(f"threads: {m_threads} vram: {vram_gb}")
    else:
        m_threads = 1

    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    if mp is None:
        # Fail with a readable message instead of "'NoneType' is not subscriptable".
        raise ValueError(f"No model parameters found for {model_path!r} (hash {model_hash})")

    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )
    mdx_sess = MDX(model_path, model, processor=processor_num)

    wave, sr = librosa.load(filename, mono=False, sr=44100)

    # Peak-normalise before inference. An all-zero (silent) input has peak 0;
    # dividing by it would fill the signal with NaN, so skip scaling then.
    peak = np.max(np.abs(wave))
    if peak == 0:
        peak = 1.0
    wave /= peak

    if denoise:
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # Undo the normalisation on the model output (``wave`` itself stays normalised).
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        # Name of the residual: explicit suffix, else the stem_naming entry,
        # else "<stem>_diff".
        invert_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
        if invert_name is None:
            invert_name = f"{stem_name}_diff"
        invert_filepath = os.path.join(output_dir, f"{base_name}_{invert_name}.wav")
        sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)

    if not keep_orig:
        os.remove(filename)

    # Drop the large objects and ask torch to release cached GPU memory.
    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
def run_mdx_beta(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False,
                 suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=1, device_base=""):
    """CPU-only variant of ``run_mdx`` that picks the thread count from the audio length.

    Args:
        model_params: Mapping from model hash (as returned by ``MDX.get_hash``)
            to a dict providing ``mdx_dim_f_set``, ``mdx_dim_t_set``,
            ``mdx_n_fft_scale_set``, ``primary_stem`` and ``compensate``.
        output_dir: Directory the output WAV files are written to.
        model_path: Path of the model file handed to ``MDX``.
        filename: Audio file to process; it is loaded at 44100 Hz.
        exclude_main: If true, the primary stem file is not written.
        exclude_inversion: If true, the inverted (residual) file is not written.
        suffix: Replaces the model's stem name in the primary output file name.
        invert_suffix: Replaces the looked-up name of the inverted output file.
        denoise: If true, the model is run on the signal and on its negation
            and the two passes are averaged.
        keep_orig: If false, ``filename`` is deleted after processing.
        m_threads: Thread count used for audio shorter than 60 s; replaced by
            8 for 60-120 s and by 16 for anything longer.
        device_base: Unused; processing always runs on the CPU.

    Returns:
        Tuple ``(main_filepath, invert_filepath)``; an entry is ``None`` when
        the corresponding output was excluded.

    Raises:
        ValueError: If ``model_params`` has no entry for the model's hash.
    """
    device = torch.device("cpu")
    processor_num = -1

    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    if mp is None:
        # Fail with a readable message instead of "'NoneType' is not subscriptable".
        raise ValueError(f"No model parameters found for {model_path!r} (hash {model_hash})")

    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )
    mdx_sess = MDX(model_path, model, processor=processor_num)

    wave, sr = librosa.load(filename, mono=False, sr=44100)

    # Take the duration from the decoded samples rather than probing the file
    # a second time; the ``filename=`` keyword of librosa.get_duration is
    # deprecated in newer librosa releases, while ``y=``/``sr=`` is stable.
    duration = librosa.get_duration(y=wave, sr=sr)
    if duration > 120:
        m_threads = 16
    elif duration >= 60:
        m_threads = 8
    logger.info(f"threads: {m_threads}")

    # Peak-normalise before inference. An all-zero (silent) input has peak 0;
    # dividing by it would fill the signal with NaN, so skip scaling then.
    peak = np.max(np.abs(wave))
    if peak == 0:
        peak = 1.0
    wave /= peak

    if denoise:
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # Undo the normalisation on the model output (``wave`` itself stays normalised).
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        # Name of the residual: explicit suffix, else the stem_naming entry,
        # else "<stem>_diff".
        invert_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
        if invert_name is None:
            invert_name = f"{stem_name}_diff"
        invert_filepath = os.path.join(output_dir, f"{base_name}_{invert_name}.wav")
        sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)

    if not keep_orig:
        os.remove(filename)

    # Drop the large objects; empty_cache() is kept from the original cleanup.
    del mdx_sess, wave_processed, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
def convert_to_stereo_and_wav(audio_path, output_dir):
    """Return the path of a multi-channel WAV version of ``audio_path``.

    The file is loaded at 44100 Hz to inspect its channel layout. If it decodes
    to a single channel, or its name does not end in ``.wav`` (case-insensitive),
    ffmpeg re-encodes it to a two-channel WAV named ``<basename>_stereo.wav``
    inside ``output_dir`` and that path is returned. Otherwise the input path
    is returned unchanged.

    Args:
        audio_path: Path of the audio file to check.
        output_dir: Directory that receives the converted file, if one is made.

    Returns:
        Path of the converted file, or ``audio_path`` if no conversion was needed.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    wave, _ = librosa.load(audio_path, mono=False, sr=44100)

    # With mono=False a single-channel file decodes to a 1-D array.
    is_mono = wave.ndim < 2
    is_wav = audio_path.lower().endswith(".wav")
    if is_wav and not is_mono:
        return audio_path

    stem = os.path.splitext(os.path.basename(audio_path))[0]
    stereo_path = os.path.join(output_dir, f"{stem}_stereo.wav")

    # Pass an argument list so paths containing quotes or spaces need no
    # shell-style quoting and cannot be mis-split.
    command = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-i", audio_path,
        "-ac", "2",
        "-f", "wav",
        stereo_path,
    ]
    subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
    return stereo_path
def get_hash(filepath):
    """Return the first 18 hex characters of the BLAKE2b digest of a file.

    The file is read in 8192-byte blocks so it never has to fit in memory.
    """
    digest = hashlib.blake2b()
    with open(filepath, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()[:18]
def random_sleep(min_seconds=5.2, max_seconds=7.9):
    """Block for a random duration drawn uniformly from the given range.

    The delay is rounded to one decimal place before sleeping. Called without
    arguments it pauses between 5.2 and 7.9 seconds, as before.

    Args:
        min_seconds: Lower bound of the delay, in seconds.
        max_seconds: Upper bound of the delay, in seconds.

    Raises:
        ValueError: Raised by ``time.sleep`` if the drawn delay is negative.
    """
    delay = round(random.uniform(min_seconds, max_seconds), 1)
    time.sleep(delay)