# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import argparse
import time
import librosa
from tqdm import tqdm
import sys
import os
import glob
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn
# Ensure the local utils module can be imported even when running under an embedded Python interpreter.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from utils import demix_track, demix_track_demucs, get_model_from_config
import warnings
warnings.filterwarnings("ignore")
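
# This script runs source-separation inference over a folder of audio mixtures:
# it loads a model from a config/checkpoint, separates each track into the
# configured instrument stems, and writes the results as wav files.
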
def run_folder(model, args, config, device, verbose=False):
    start_time = time.time()
    model.eval()
    all_mixtures_path = glob.glob(args.input_folder + '/*.*')
    all_mixtures_path.sort()
    print('Total files found: {}'.format(len(all_mixtures_path)))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    if not os.path.isdir(args.store_dir):
        os.mkdir(args.store_dir)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")

    if args.disable_detailed_pbar:
        detailed_pbar = False
    else:
        detailed_pbar = True

    for path in all_mixtures_path:
        print("Starting processing track: ", path)
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})
        try:
            # mix, sr = sf.read(path)
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue
        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        mix_orig = mix.copy()

        # Optional per-track normalization: scale the mixture by the mean/std of its
        # mono downmix (undone below before writing the separated stems)
        if 'normalize' in config.inference:
            if config.inference['normalize'] is True:
                mono = mix.mean(0)
                mean = mono.mean()
                std = mono.std()
                mix = (mix - mean) / std

        mixture = torch.tensor(mix, dtype=torch.float32)
        if args.model_type == 'htdemucs':
            res = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
        else:
            res = demix_track(config, model, mixture, device, pbar=detailed_pbar)
        for instr in instruments:
            estimates = res[instr].T
            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    estimates = estimates * std + mean
            file_name, _ = os.path.splitext(os.path.basename(path))
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.wav")
            sf.write(output_file, estimates, sr, subtype='FLOAT')

        # Output "instrumental", which is an inverse of 'vocals' (or first stem in list if 'vocals' absent)
        if args.extract_instrumental:
            file_name, _ = os.path.splitext(os.path.basename(path))
            instrum_file_name = os.path.join(args.store_dir, f"{file_name}_instrumental.wav")
            if 'vocals' in instruments:
                estimates = res['vocals'].T
            else:
                estimates = res[instruments[0]].T
            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    estimates = estimates * std + mean
            sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype='FLOAT')

    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))

def proc_folder(args):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of bandit, bandit_v2, bs_roformer, htdemucs, mdx23c, mel_band_roformer, scnet, scnet_unofficial, segm_models, swin_upernet, torchseg")
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="initial checkpoint with valid model weights")
    parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
    parser.add_argument("--store_dir", default="", type=str, help="path to store result wav files")
    parser.add_argument("--device_ids", nargs='+', type=int, default=0, help="list of gpu ids")
    parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided")
    parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
    parser.add_argument("--force_cpu", action='store_true', help="force the use of CPU even if CUDA is available")
    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    device = "cpu"
    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        device = f'cuda:{args.device_ids}' if type(args.device_ids) == int else f'cuda:{args.device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"
    print("Using device: ", device)
    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True
    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != '':
        print('Start from checkpoint: {}'.format(args.start_check_point))
        if args.model_type == 'htdemucs':
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
            # Fix for htdemucs pretrained models
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    # in case multiple CUDA GPUs are used and --device_ids arg is passed
    if type(args.device_ids) != int:
        model = nn.DataParallel(model, device_ids=args.device_ids)
    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    run_folder(model, args, config, device, verbose=True)

if __name__ == "__main__":
    proc_folder(None)
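
# Illustrative invocation (the config and checkpoint paths below are placeholders,
# not files shipped alongside this script):
#   python inference.py --model_type mdx23c \
#       --config_path configs/config_mdx23c.yaml \
#       --start_check_point checkpoints/model.ckpt \
#       --input_folder input/ --store_dir separated/ --extract_instrumental
# The entry point can also be called programmatically by passing an argument list,
# e.g. proc_folder(["--model_type", "htdemucs", "--input_folder", "input/", "--store_dir", "separated/"]).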